Abstractions in JAX
Originally posted on LinkedIn.
I’m considering transitioning from TensorFlow to JAX and, so far, I’m loving how effectively JAX exposes low-level behavior while still providing useful abstractions.
For example, the code snippet¹ below shows how one can perform gradient descent while utilising multiple devices:
- the gradients are computed on each device for its own shard of the minibatch
- they are synchronised across devices and averaged
- the new parameters are computed by stepping in the direction opposite to the gradient
import functools

import jax

# `loss_fn(params, xs, ys)` is assumed to be defined elsewhere.
@functools.partial(jax.pmap, axis_name="num_devices")
def update(params, xs, ys, learning_rate=0.005):
    # 1. Compute the gradients on the given minibatch
    #    (individually on each device).
    grads = jax.grad(loss_fn)(params, xs, ys)

    # 2. Combine the gradients across all devices
    #    (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name="num_devices")

    # 3. Each device performs its own update, but since we
    #    start with the same params and synchronise gradients,
    #    the params stay in sync.
    new_params = jax.tree_map(
        lambda param, g: param - g * learning_rate,
        params,
        grads,
    )
    return new_params
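For completeness, here is a minimal usage sketch that is not part of the original tutorial: it assumes a toy linear-model `loss_fn` and random data, and only illustrates the pattern `jax.pmap` expects, namely parameters replicated on every device and a minibatch whose leading axis equals the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np

def loss_fn(params, xs, ys):
    # Hypothetical mean-squared-error loss for a linear model.
    preds = xs @ params["w"] + params["b"]
    return jnp.mean((preds - ys) ** 2)

n_devices = jax.local_device_count()
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}

# Replicate the parameters so every device starts from the same values.
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# Shard the minibatch: the leading axis must equal the number of devices.
xs = jnp.asarray(np.random.normal(size=(n_devices, 32, 3)))
ys = jnp.asarray(np.random.normal(size=(n_devices, 32)))

# One synchronised gradient-descent step across all devices.
replicated_params = update(replicated_params, xs, ys)
```

The leading device axis is what `pmap` maps over; because `pmean` synchronises the gradients, every device ends the step with identical parameters.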
¹ Adapted from a tutorial by DeepMind’s Vladimir Mikulik and Roman Ring.