import jax.numpy as jnp
import numpy as np
from jax import random
from jax import device_put
from jax import grad, vmap, pmap, jit, make_jaxpr
2 Introduction to JAX
JAX is a framework that enables high-performance numerical computing by merging Autograd and XLA (Accelerated Linear Algebra). Autograd was originally a library for automatic differentiation of Python and NumPy code; its functionality has since been incorporated into JAX. XLA is an optimizing compiler for machine learning, which can significantly speed up workloads on both TPU and GPU devices.
2.0.1 Why learn JAX?
2.0.1.1 High performance
JAX has excellent support for accelerators such as GPUs and TPUs and leverages both XLA and Just-In-Time compilation for speedy numerical computation.
2.0.1.2 Automatic differentiation
JAX’s grad transformation makes it trivial to calculate gradients of complex functions for gradient descent, backpropagation and other optimization algorithms.
Research and experimentation: the framework’s flexibility grants developers low-level programming power for effective and rapid prototyping and research. The ease with which JAX code can be run on any device, together with its functional programming style and dynamic computation graphs, makes it ideal for experimenting with different machine learning models.
2.0.1.3 Composable and functional
JAX encourages a functional programming style, enabling clean and modular code. Its functions are pure, meaning they produce the same output from the same input every time, limiting side effects and improving safety and reusability.
Plays well with others: JAX is interoperable with NumPy and can be used in conjunction with TensorFlow and PyTorch, so users often combine the features of other libraries with the performance benefits of JAX.
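As a small illustration of this functional style (a minimal sketch), JAX arrays are immutable: updates that would mutate a NumPy array in place instead return a new array.

import jax.numpy as jnp
import numpy as np

np_x = np.zeros(3)
np_x[0] = 1.0                    # NumPy mutates the array in place

jnp_x = jnp.zeros(3)
jnp_x = jnp_x.at[0].set(1.0)     # JAX returns a new, updated array
print(jnp_x)                     # [1. 0. 0.]

Keeping functions free of such side effects is what allows JAX to trace, compile and transform them safely.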
2.0.1.4 How does JAX compare to other frameworks?
2.0.1.5 TensorFlow
TensorFlow creates static computational graphs that define operations and dependencies before execution, enabling efficient optimization and deployment across devices. The tf.GradientTape API supports automatic differentiation, and the framework has extensive support for TPUs and GPUs. It has a large and mature ecosystem and strong industry adoption. High-level APIs such as Keras and tf.Module make development easier, and TensorFlow Hub offers a repository of pre-trained models.
2.0.1.6 PyTorch
PyTorch uses “eager” execution, dynamically building computation graphs as operations are invoked. This flexibility allows for more intuitive experimentation and debugging. The torch.autograd module enables automatic differentiation, computing gradients during the backward pass. The framework supports TPUs via the XLA compiler and provides torch.cuda for GPU memory management. PyTorch has a growing ecosystem focused on research and flexibility, and many state-of-the-art models are implemented and shared first in the framework. Its high-level API simplifies building neural networks and includes modules for optimization, data handling and visualization.
2.0.1.7 JAX
JAX combines elements of static and dynamic compilation, providing a hybrid approach in which code can be executed “just-in-time” or traced into static computation graphs. This enables efficient execution and optimization while retaining flexibility and ease of debugging. The framework offers automatic differentiation through its functional programming model, leveraging function transformations to compute gradients and allow fine-grained control over differentiation.
JAX enjoys excellent GPU and TPU support, integrating tightly with XLA and requiring zero code changes to switch between devices. Spreading data and computation across cores is made simple via function transformations such as pmap. JAX has a smaller but rapidly growing community and is popular with research groups such as Google DeepMind. Its lower-level API is intuitive to those already familiar with NumPy, and the neural network libraries Flax and Haiku provide modules and optimizers for training models.
2.0.2 Key concepts
JAX provides an API similar to NumPy that is intuitive and familiar to many researchers and engineers. The framework includes composable function transformations for just-in-time compilation, batching, automatic differentiation and parallelization. JAX can be run on TPU, GPU and CPU without any code changes.
2.0.3 Accelerated NumPy
x = jnp.arange(10)
print(x)
[0 1 2 3 4 5 6 7 8 9]
We can move this array from CPU to GPU or TPU.
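For example, device_put places an array on a chosen device. A minimal sketch, assuming a GPU or TPU runtime is attached (otherwise the target is simply the CPU):

import jax
from jax import device_put

x = jnp.arange(10)
# Explicitly place the array on the first visible device.
x_on_device = device_put(x, jax.devices()[0])

In practice JAX places arrays on the default backend automatically; device_put is useful when you want explicit control over placement.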
2.0.4 A word on Colab TPUs
It used to be easy to switch from GPU to TPU in Colab; however, TPU support is now behind JAX version >= 0.4, which requires TPU VMs (on GCP).
JAX 0.4 and newer requires TPU VMs, which Colab does not provide at this time. You can still use jax 0.3.25 on a Colab TPU, which is the version installed by default on Colab TPU runtimes. If you’ve already updated JAX, you can choose Runtime -> Disconnect and Delete Runtime to get a fresh TPU runtime, and then skip the pip install step so that you keep the default jax/jaxlib version 0.3.25.
For now, we will proceed with GPUs and run code on Cloud TPU VMs later in the course.
import jax
print(jax.device_count())
device_type = jax.devices()[0].device_kind
device_type
1
'Tesla T4'
More on how JAX creates random numbers later. For now, let’s initialize a pseudo random number generator (PRNG) key.
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
key, subkey = random.split(key)
x = random.normal(key, (1000, 1000))
print(f"x is of shape: {x.shape}")
print(f"x has dtype: {x.dtype}")
x is of shape: (1000, 1000)
x has dtype: float32
import numpy as np
x = np.array(x)

def x_on_cpu(x):
    return np.dot(x, x)
%timeit -n 1 -r 1 x_on_cpu(x)
29.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
def x_on_gpu(x):
    return jnp.dot(x, x)
%timeit -n 5 -r 5 x_on_gpu(x).block_until_ready()
3.03 ms ± 723 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
2.0.5 Random numbers
Generating random numbers can seem complicated at first glance.
Pseudo random number generation (PRNG) creates sequences that aren’t truly random, because they are determined by their initial value, the seed. Each random sample is a deterministic function of a state that is carried over from one sample to the next.
In NumPy, PRNG is based on a global state, managed by the numpy.random module.
def numpy_random_state():
    print(str(np.random.get_state())[:100], '...')
numpy_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 40 ...
This state is then updated by each call to a random function.
np.random.seed(0)
numpy_random_state()

_ = np.random.uniform()
numpy_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 40 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 6 ...
JAX handles PRNG differently, since the framework is designed to make programs easy to reproduce, parallelize and vectorize. Rather than using a global state, JAX uses an explicit state called a key.
key = random.PRNGKey(10)
print(key)
[ 0 10]
Random functions consume, and don’t alter, the key. This means the same key should always produce the same sample.
for i in range(0, 3):
    print(random.normal(key))
-1.3445405
-1.3445405
-1.3445405
One practice to bear in mind, then, is never to reuse keys unless identical outputs are required. We can obtain independent keys by using the split() function.
key, subkey = random.split(key)
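A minimal sketch of the usual pattern: split the key, draw the sample from the subkey, and keep the new key for future splits.

key = random.PRNGKey(0)

key, subkey = random.split(key)
sample_1 = random.normal(subkey)

key, subkey = random.split(key)
sample_2 = random.normal(subkey)

print(sample_1, sample_2)   # two different samples, since each subkey is fresh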
2.0.6 Intermediate representations
Introducing the Jaxpr
An intermediate representation is an internal interpretation of machine learning code used by underlying frameworks or compilers to optimize the program. When we write code in a framework such as JAX or PyTorch, it is converted from high-level code into a computational graph, or symbolic representation. This is further transformed into an intermediate representation (IR) optimized for efficiency. The IR is then optimized over several passes to conduct operations such as constant folding, operation fusion, model parallelization, and quantization. This is followed by hardware-specific compilation, to convert the IR into low-level code optimized for the required backend (TPU, GPU etc).
The jaxpr
JAX converts functions into an intermediate representation called a jaxpr. Transformations such as grad then work with this jaxpr representation. JAX works by tracing functions. Before we look at what that means, consider this simple Python function:
def sum_squares(x):
    return jnp.sum(x**2)
What does this function do? It looks easy at first glance: it sums the squares of x, where x could be a single value or an array. However, given Python’s dynamism, it could do anything depending on what x is… square an ice-cream order, a print job, or jackpot winnings in a slot machine.
JAX takes advantage of this dynamism by running functions with tracer values: abstract stand-in inputs that record the operations applied to them, helping JAX understand what a function does and what it will accomplish.
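As a quick sketch of what tracing looks like (using the make_jaxpr helper introduced just below; the exact printout varies by JAX version), printing inside a function while JAX traces it shows the tracer rather than a concrete value:

def show_tracer(x):
    # At trace time, x is an abstract tracer recording operations, not a concrete array.
    print("traced value:", x)
    return jnp.sum(x**2)

make_jaxpr(show_tracer)(3.0)
# prints something like: traced value: Traced<ShapedArray(float32[], weak_type=True)>...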
We can take a look at the IR, the jaxpr, by using the jax.make_jaxpr function.
from jax import make_jaxpr
print(make_jaxpr(sum_squares)(3.0))
{ lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
d:f32[] = reduce_sum[axes=()] c
in (d,) }
This result comprises the primitive operations, exposed in the jax.lax module, that JAX knows how to transform.
Herein lies JAX’s power: with no need to complicate the API, JAX develops a sound idea of what the function is doing, and knows how to vectorize it with vmap, parallelize it with pmap, and just-in-time compile it with jit.
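As a small aside (not required for everyday JAX code), these primitives can also be called directly from jax.lax; for instance, the integer_pow primitive that appeared in the jaxpr above:

from jax import lax

# x**2 in the jnp code above lowers to this primitive.
print(lax.integer_pow(3.0, 2))   # 9.0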
2.1 Transformations
Modular, functional programming
JAX can transform functions: a numerical function can be returned as a new function that, for example, computes the gradient of the original function, or parallelizes it. It can also do both at once!
2.1.1 grad
One of the most commonly used transformations, jax.grad calculates the gradient of a function.
from jax import grad
def sum_squares(x):
    return jnp.sum(x**2)
Since jax.grad(f) computes the gradient of function f, jax.grad(f)(x) is the gradient of f at x.
print(grad(sum_squares)(3.0))
6.0
print(grad(grad(sum_squares))(3.0))
2.0
import math
def cylinder_volume(r, h):
    vol = jnp.pi * r**2 * h
    return vol
# Compute the volume of a cylinder with radius 3, and height 3
print(cylinder_volume(3, 3))
84.82300164692441
print(grad(cylinder_volume)(4.0, 8.0))
print(grad(cylinder_volume)(2.0, 6.0))
201.06194
75.398224
print(grad(grad(cylinder_volume))(4.0, 8.0))
print(grad(grad(cylinder_volume))(2.0, 6.0))
50.265484
37.699112
We can use argnums to calculate the gradient with respect to different arguments:
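For example (a short sketch using the cylinder_volume function defined above; by default grad differentiates with respect to the first argument):

# d/dh (pi * r**2 * h) = pi * r**2
print(grad(cylinder_volume, argnums=1)(4.0, 8.0))        # ~50.265
# gradients with respect to both r and h at once
print(grad(cylinder_volume, argnums=(0, 1))(4.0, 8.0))

grad also differentiates straight through ordinary Python control flow, as the next example shows: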
def f(x):
    if x > 0:
        return 2 * x ** 3
    else:
        return 3 * x

key = random.PRNGKey(0)
x = random.normal(key, ())
print(key)
print(x)
print(grad(f)(x))
print(grad(f)(-x))
[0 0]
-0.20584226
3.0
0.2542262
An obvious example that makes use of grad would be a loss function.
def loss(preds, targets):
    return jnp.sum((preds - targets)**2)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
y = jnp.asarray([3.0, 3.0, 3.0, 3.0])

print(grad(loss)(x, y))
[-4. -2. 0. 2.]
2.1.2 Value and grad
We can return both the value and gradient of a function using value_and_grad. This is a common pattern in machine learning for logging training loss.
from jax import value_and_grad
value_and_grad(loss)(x, y)
(Array(6., dtype=float32), Array([-4., -2., 0., 2.], dtype=float32))
2.1.3 jit
The jax.jit() transformation performs Just-In-Time (JIT) compilation of a JAX Python function for efficient execution in XLA.
Let’s go back to our sum_squares() function and time its original implementation on an array of 100 numbers.
from jax import jit
def sum_squares(x):
    return jnp.sum(x**2)

x = jnp.arange(100)
%timeit sum_squares(x).block_until_ready()
419 µs ± 41.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Let’s jit the function and notice the speed improvement.
sum_squares_jit = jit(sum_squares)
# Warm up
sum_squares_jit(x).block_until_ready()
%timeit sum_squares_jit(x).block_until_ready()
72.6 µs ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In case this notation isn’t familiar, µs denotes a microsecond (a millionth of a second) and ns a nanosecond (a billionth of a second). Our jitted function is considerably faster in this simple example. Note: because of JAX’s asynchronous dispatch, the Python call can return before the computation has finished, so a returned array may not yet be populated. Calling block_until_ready() ensures we time the actual computation rather than just the dispatch.
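jit is also commonly applied as a decorator, a form we use in the vmap section below; a minimal sketch:

@jit
def sum_squares_decorated(x):
    return jnp.sum(x**2)

# The first call traces and compiles the function; later calls with the same
# input shape and dtype reuse the compiled XLA computation.
print(sum_squares_decorated(jnp.arange(100)))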
2.1.5 vmap
The jax.vmap transformation generates a vectorized implementation of a function.
Reference for this section (thanks to DeepMind).
We can loop over a batch in Python, but such operations tend to be costly.
from jax import vmap
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
3.62 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
1.17 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@jit
def jit_vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)
print('jitted and auto-vectorized with vmap')
%timeit jit_vmap_batched_apply_matrix(batched_x).block_until_ready()
jitted and auto-vectorized with vmap
86.9 µs ± 10 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Putting it all together
We take a loss function, use grad to find its gradients, vectorize it with vmap to work across batches, then jit-compile it, all in one line.
import jax.numpy as jnp
from jax import grad, vmap, jit
def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def mse_loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    loss = jnp.sum((preds - targets) ** 2)
    print(loss)  # under jit/vmap this prints a tracer, not a concrete value
    return loss

gradients = jit(grad(mse_loss))
vectorized_gradients = jit(vmap(grad(mse_loss), in_axes=(None, 0)))
2.1.6 pmap
The jax.pmap transformation parallelizes a function across multiple devices, replicating the computation and mapping the leading axis of its inputs over the available accelerators.
def pmap_batched_apply_matrix(v_batched):
    return pmap(apply_matrix)(v_batched)

pmap_batched_apply_matrix(batched_x)
NameError: ignored
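The call above fails on our single-GPU Colab runtime: pmap must be imported, and it maps the leading axis of its input across separate devices, so a batch of 10 rows needs 10 devices. A hedged sketch of the typical pattern on a multi-device host, such as the Cloud TPU VMs we use later, reshaping the data so the leading axis matches the device count:

from jax import pmap

n_devices = jax.local_device_count()
# Give each device one row to work on: leading axis == number of devices.
per_device_x = random.normal(key, (n_devices, 100))
out = pmap(apply_matrix)(per_device_x)
print(out.shape)   # (n_devices, 150), one result per device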