JAX Primer

JAX gives you NumPy-style numerics with autodiff, JIT compilation, and vectorization on accelerators. This primer walks through the four core transforms you will reach for daily when wiring Meridian inference and training loops: jit, grad, vmap, and pmap.

1. Pure functions and tracing

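JAX traces your Python function with abstract stand-ins for its array arguments, then compiles the trace via XLA. Tracing happens once per combination of input shapes and dtypes, not once per call, so side effects run only at trace time, mutations of outside state are not captured in the compiled code, and Python branches on traced values raise an error under jit. Keep functions pure: inputs go in as arrays, outputs come out as arrays, and no globals are touched along the way.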
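A minimal sketch of what tracing means in practice. The names impure_scale, pure_relu, and trace_log are hypothetical and exist only for illustration. The list append is a deliberate side effect: it fires while JAX traces the function, not on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state (hypothetical), touched only while tracing

@jax.jit
def impure_scale(x):
    # Side effect: runs at trace time, not on each call to the compiled function
    trace_log.append("traced")
    return x * 2.0

@jax.jit
def pure_relu(x):
    # Branch with jnp.where rather than a Python `if` on a traced value
    return jnp.where(x > 0, x, 0.0)

x = jnp.arange(3.0)
impure_scale(x)
impure_scale(x)                 # same shape and dtype: served from the cache
n_same_shape = len(trace_log)   # the side effect fired once, not twice

impure_scale(jnp.arange(5.0))   # new shape: JAX traces again
n_new_shape = len(trace_log)

clipped = pure_relu(jnp.array([-1.0, 2.0]))
```

The side effect is the thing to avoid here. It is included only to make the trace boundary visible.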

2. JIT compilation

Wrap a hot path in jax.jit to fuse operations and lower them to CPU, GPU, or TPU kernels. The first call pays the trace and compile cost; every subsequent call with the same input shapes and dtypes reuses the cached executable.

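import jax
import jax.numpy as jnp

@jax.jit
def loss(params, x, y):
    # Linear model: x has shape (n, d), params has shape (d,)
    pred = jnp.dot(x, params)
    # Mean squared error, a scalar
    return jnp.mean((pred - y) ** 2)

# Compiled gradient of loss with respect to params (the first argument)
grad_fn = jax.jit(jax.grad(loss))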
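As a usage sketch, the snippet below calls loss and grad_fn on a tiny toy dataset and compares the result with the closed-form gradient of mean squared error. The data values are assumptions chosen so the arithmetic is easy to follow by hand.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))

# Toy data, assumed for illustration only
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.zeros(2)
params = jnp.ones(2)

loss_value = loss(params, x, y)   # first call traces and compiles
grads = grad_fn(params, x, y)     # gradient with respect to params

# Closed-form MSE gradient for comparison: (2 / n) * x^T (x @ params - y)
n = x.shape[0]
expected = (2.0 / n) * (x.T @ (x @ params - y))
```

With an identity design matrix and zero targets, the predictions equal params, so the loss is 1.0 and each gradient component is 1.0.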

3. Autodiff and vectorization

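jax.grad returns a function that computes the gradient of a scalar-valued function with respect to its first argument by default (pass argnums to choose another). Compose it with jax.vmap to batch over an axis without writing explicit loops, and with jax.pmap to run shards of the batch on separate devices. The transforms compose, but order matters: jax.vmap(jax.grad(f)) yields per-example gradients, whereas jax.grad(jax.vmap(f)) fails unless the batched output is first reduced to a scalar.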
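A sketch of the two compositions, assuming a linear model with squared error. The names example_loss and batch_loss, the random data, and the batch size of four examples per device are all hypothetical. The pmap call requires the leading axis to match the number of local devices, so the batch is reshaped to (n_dev, n // n_dev, d); on a single-device machine this is one shard.

```python
import jax
import jax.numpy as jnp

def example_loss(params, x, y):
    # Squared error for ONE example: x has shape (d,), y is a scalar
    return (jnp.dot(x, params) - y) ** 2

def batch_loss(params, xs, ys):
    # Mean squared error over a batch: xs is (n, d), ys is (n,)
    return jnp.mean((jnp.dot(xs, params) - ys) ** 2)

n_dev = jax.local_device_count()
n, d = 4 * n_dev, 3  # assumed sizes; n is divisible by n_dev by construction
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (n, d))
ys = jax.random.normal(key_y, (n,))
params = jnp.ones(d)

# vmap over grad: params is shared (None), xs and ys are mapped over axis 0
per_example_grads = jax.vmap(
    jax.grad(example_loss), in_axes=(None, 0, 0)
)(params, xs, ys)

# Reference: gradient of the full-batch mean loss
full_grad = jax.grad(batch_loss)(params, xs, ys)

# pmap over grad: each device gets an equal shard and computes its own gradient
sharded_xs = xs.reshape(n_dev, n // n_dev, d)
sharded_ys = ys.reshape(n_dev, n // n_dev)
per_device_grads = jax.pmap(
    jax.grad(batch_loss), in_axes=(None, 0, 0)
)(params, sharded_xs, sharded_ys)
```

Because the shards are equal in size, averaging per_device_grads over the device axis recovers full_grad, and so does averaging per_example_grads over the example axis. In a real training loop the cross-device average would more likely happen inside the mapped function with a collective such as jax.lax.pmean.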