A JAX-inspired automatic differentiation compiler that implements key JAX features such as function transformations (jit, vmap, grad) and operation fusion. Built on a PyTorch backend with Metal acceleration for Apple Silicon.
- Just-In-Time Compilation (jit): Cache and reuse computation graphs for improved performance
 - Vectorized Mapping (vmap): Automatically vectorize functions across batch dimensions
 - Automatic Differentiation (grad): Compute gradients of functions with respect to inputs
 - Composable Transformations: Stack transformations like @jit, @vmap, and @grad
- Metal Performance Shaders: Optimized for Apple Silicon
 - Automatic Device Placement: Seamlessly handles CPU and GPU operations
 - Operation Fusion: Automatically fuses compatible operations for better performance
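
To make the caching and vectorization ideas above concrete, here is a minimal pure-Python sketch. It is not this project's implementation: `cached_jit`, `simple_vmap`, and `compile_count` are hypothetical names, and the "compile" step is only simulated by storing the function in a cache keyed on argument types.

```python
import functools


def cached_jit(fn):
    """Hypothetical stand-in for jit: 'compile' once per argument-type signature."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        key = tuple(type(a) for a in args)
        if key not in cache:
            # A real compiler would trace fn and build an optimized graph here.
            wrapper.compile_count += 1
            cache[key] = fn
        return cache[key](*args)

    wrapper.compile_count = 0
    return wrapper


def simple_vmap(fn):
    """Hypothetical stand-in for vmap: apply fn to every element of a batch."""
    def mapped(xs):
        return [fn(x) for x in xs]
    return mapped


@cached_jit
def square(x):
    return x * x


batch_square = simple_vmap(square)
```

Calling `square` repeatedly with floats triggers the simulated compile step only once; a real tracer would typically key the cache on shapes and dtypes rather than Python types.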
 
- Computation Graph: Track operations and dependencies for optimization
 - Automatic Differentiation: Reverse-mode autodiff with efficient gradient computation
 - Operation Fusion: Identify and combine operations for better performance
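
The graph tracking and reverse-mode autodiff described above can be sketched in a few lines of plain Python. This is an illustrative toy, not the code in `src/core`: the `Node` class and `backward` function are assumptions, and `constant`, `add`, `mul`, and `grad` here are standalone stand-ins that only mirror the names used in this README.

```python
class Node:
    """One value in the computation graph, with links to the nodes it depends on."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # sequence of (parent_node, local_derivative)
        self.grad = 0.0


def constant(v):
    return Node(float(v))


def add(a, b):
    return Node(a.value + b.value, ((a, 1.0), (b, 1.0)))


def mul(a, b):
    return Node(a.value * b.value, ((a, b.value), (b, a.value)))


def backward(output):
    """Accumulate d(output)/d(node) into node.grad for every node in the graph."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)  # parents are appended before their children

    visit(output)
    output.grad = 1.0
    for node in reversed(order):  # children first, so node.grad is complete
        for parent, local in node.parents:
            parent.grad += local * node.grad


def grad(fn):
    """Return a function computing d fn / d x for a scalar input x."""
    def grad_fn(x):
        x_node = Node(float(x))
        backward(fn(x_node))
        return x_node.grad
    return grad_fn


def f(x):
    return add(mul(x, x), constant(1.0))


df = grad(f)
```

Here `df(2.0)` evaluates the derivative of x² + 1 at x = 2. Walking the graph in reverse topological order is what lets each node's gradient be final before it is pushed to its parents.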
 
- Python 3.8 or higher
 - pip (Python package installer)
 
There are two ways to install the package. The first is recommended if you want to modify the code or run examples:
# Clone the repository
git clone https://github.com/codingwithsurya/jax-autodiff.git
cd jax-autodiff
# Install in development mode
pip install -e .

If you just want to use the package:
pip install git+https://github.com/codingwithsurya/jax-autodiff.git

After installation, you can run any example directly:
# Run complex autodiff example
python3 examples/complex_autodiff.py
# Run other examples
python3 examples/your_example.py

The project includes a comprehensive test suite that verifies all core functionality:
# Install test dependencies first
pip install -e ".[dev]"
# Run all tests
pytest tests/
# Run specific test files
pytest tests/test_autodiff.py
pytest tests/test_compiler.py
pytest tests/test_transforms.py
# Run tests with verbose output
pytest -v tests/

The test suite covers:
- Automatic differentiation (test_autodiff.py)
- Compiler functionality (test_compiler.py)
- Function transformations (test_transforms.py)

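
The contents of these test files are not shown here. As a rough illustration of how an autodiff test is often written, the sketch below compares an analytic gradient against a central finite difference. The helper names `numerical_grad` and `check_grad` are hypothetical and are not part of this package.

```python
def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def check_grad(analytic_grad, f, x, tol=1e-4):
    """Return True if the analytic gradient agrees with the numerical estimate."""
    return abs(analytic_grad(x) - numerical_grad(f, x)) <= tol


def f(x):
    return x * x + 1.0


def f_prime(x):
    # Hand-written derivative, standing in for the output of a grad transform.
    return 2.0 * x
```

In a pytest file this would typically appear as `assert check_grad(f_prime, f, 2.0)`, with `f_prime` replaced by the gradient function under test.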
from src.core.tracer import constant, add, mul
from src.transforms.jit import jit
from src.transforms.vmap import vmap
from src.transforms.grad import grad, value_and_grad
# Define a function
def f(x):
    return add(mul(x, x), constant(1.0))
# JIT compilation
f_fast = jit(f)
result = f_fast(2.0)  # Uses cached computation graph
# Vectorization
batch_f = vmap(f)
batch_result = batch_f([1.0, 2.0, 3.0])  # Applies f to each element
# Gradients
df = grad(f)
gradient = df(2.0)  # Computes df/dx at x=2.0
# Combined transformations
@jit
@vmap
@grad
def optimized_f(x):
    return add(mul(x, x), constant(1.0))

# Loss value and gradients in one call
def loss(params, data):
    # Your model here
    return prediction_error
value_grad_fn = value_and_grad(loss)
(loss_value, gradients), aux = value_grad_fn(params, data)

Operations are automatically fused when possible:
@jit
def fused_ops(x, y):
    a = add(x, y)
    b = mul(a, a)
    return b  # add and mul operations may be fused.
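
One way to picture what fusion buys, independent of how this project's optimizer actually works: instead of running one pass over the data per operation, elementwise operations are composed into a single function that is applied once per element. The names `fuse`, `run_unfused`, and `run_fused` below are hypothetical and use plain Python lists in place of tensors.

```python
def fuse(*ops):
    """Compose elementwise ops into a single callable."""
    def fused(x):
        for op in ops:
            x = op(x)
        return x
    return fused


def run_unfused(ops, xs):
    """One full pass over the data per op, with an intermediate list each time."""
    for op in ops:
        xs = [op(x) for x in xs]
    return xs


def run_fused(ops, xs):
    """A single pass over the data with no intermediate lists."""
    kernel = fuse(*ops)
    return [kernel(x) for x in xs]


# Elementwise version of fused_ops with y fixed to 1.0: (x + 1) * (x + 1)
ops = [lambda x: x + 1.0, lambda x: x * x]
```

Both runners return the same values; the fused version avoids materializing the intermediate result of the add, which is the kind of saving a fused GPU kernel aims for.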
├── examples/          # Example usage and benchmarks
├── src/
│   ├── core/         # Core autodiff and tracing
│   ├── metal/        # Metal acceleration
│   ├── optimizations/# Graph optimizations
│   └── transforms/   # Function transformations
└── tests/            # Unit tests
Feel free to open issues or submit pull requests. Areas of interest:
- Additional function transformations
 - More optimization passes
 - Extended hardware support
 - Performance improvements
 
MIT License
Inspired by the JAX project and its functional programming approach to automatic differentiation.