2 Introduction to JAX

JAX is a framework for high-performance numerical computing that brings together Autograd and XLA (Accelerated Linear Algebra). Autograd is a library for automatic differentiation of native Python and NumPy code, and an updated version of it is built into JAX. XLA is an optimizing compiler for machine learning that can significantly speed up workloads on both GPU and TPU devices.

2.0.1 Why learn JAX?

2.0.1.1 High performance

JAX has excellent support for accelerators such as GPUs and TPUs, and uses XLA for just-in-time (JIT) compilation, which makes numerical computation fast.

2.0.1.2 Automatic differentiation

JAX’s grad transformation makes it trivial to calculate gradients of complex functions, as needed for gradient descent, backpropagation and other optimization algorithms.

Research and experimentation: The framework’s flexibility gives developers low-level control for rapid prototyping and research. Because JAX code runs on any device without changes, and thanks to its functional programming style and dynamic computation graphs, JAX is well suited to experimenting with different machine learning models.

2.0.1.3 Composable and functional

JAX encourages a functional programming style, which leads to clean, modular code. Its transformations are designed for pure functions, which always produce the same output for the same input and have no side effects; this improves safety and reusability.

Plays well with others: JAX is interoperable with NumPy and can be used in conjunction with TensorFlow and PyTorch. Users often combine the features of other libraries with the performance benefits of JAX.
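Purity matters in practice because a transformation such as jax.jit runs the Python body only while tracing; later calls with the same input shapes and dtypes reuse the compiled result. A minimal sketch, where the list trace_log is a hypothetical stand-in for any side effect (printing, writing to a global, and so on):

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical side effect we can observe from outside


def scale(x):
    # This append only runs while JAX traces the function, not on cached calls
    trace_log.append(x.shape)
    return 2.0 * x


scale_jit = jax.jit(scale)
y1 = scale_jit(jnp.ones(3))  # first call: traced and compiled, side effect runs
y2 = scale_jit(jnp.ones(3))  # same shape and dtype: cached version, no side effect
y3 = scale_jit(jnp.ones(4))  # new shape: retraced, side effect runs again
print(trace_log)  # [(3,), (4,)]
```

The side effect fired only twice in three calls, so code that relies on side effects inside jitted functions can silently misbehave; pure functions avoid the problem.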

2.0.1.4 How does JAX compare to other frameworks?

2.0.1.5 TensorFlow

TensorFlow can build static computational graphs that define operations and their dependencies before execution (in TensorFlow 2, eager execution is the default and graphs are built with tf.function), enabling efficient optimization and deployment across devices. The tf.GradientTape API provides automatic differentiation, and the framework has extensive TPU and GPU support. TensorFlow has a large, mature ecosystem and strong industry adoption. High-level APIs such as Keras and tf.Module make development easier, and TensorFlow Hub offers a repository of pre-trained models.

2.0.1.6 PyTorch

PyTorch uses “eager” execution, building the computation graph dynamically as operations are invoked. This flexibility can make experimentation and debugging more intuitive. The torch.autograd module provides automatic differentiation, computing gradients during the backward pass. The framework has good support for TPUs through the XLA compiler (via the PyTorch/XLA package), and the torch.cuda module handles GPU devices and memory. PyTorch has a growing ecosystem focused on research and flexibility, and many state-of-the-art models are implemented and shared first in PyTorch. Its high-level API simplifies building neural networks and includes modules for optimization, data handling and visualization.

2.0.1.7 JAX

JAX takes a hybrid approach: code can be executed eagerly, one operation at a time, or traced into a static computation graph and compiled “just-in-time”. This enables efficient execution and optimization while retaining flexibility and ease of debugging. The framework offers automatic differentiation through its functional programming model, using function transformations to compute gradients and allowing fine-grained control over differentiation.

JAX has excellent GPU and TPU support, integrates tightly with XLA, and requires no code changes to switch between devices. Function transformations such as pmap make it simple to spread data and computation across devices. JAX has a smaller but rapidly growing community and is popular with researchers, for example at Google DeepMind. Its lower-level API is intuitive to anyone already familiar with NumPy, and neural network libraries such as Flax and Haiku provide modules for building and training models (optimizers commonly come from the companion library Optax).

2.0.2 Key concepts

JAX provides a NumPy-like API that is intuitive and familiar to many researchers and engineers. The framework includes composable function transformations for just-in-time compilation, batching, automatic differentiation and parallelization. JAX code runs on TPU, GPU and CPU without any code changes.
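As a small illustration of how these transformations compose, the sketch below (the function f is an arbitrary example chosen for this purpose) differentiates a scalar function with jax.grad, vectorizes the derivative over a batch with jax.vmap and compiles the result with jax.jit:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2


df = jax.grad(f)                    # derivative of a scalar-valued function
batched_df = jax.jit(jax.vmap(df))  # vectorized over a batch, then compiled

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_df(xs))  # should match the analytic derivative sin(2x)
```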

2.0.3 Accelerated NumPy

import jax.numpy as jnp
import numpy as np
from jax import random
from jax import device_put
from jax import grad, vmap, pmap, jit, make_jaxpr
x = jnp.arange(10)
print(x)
[0 1 2 3 4 5 6 7 8 9]

We can move this array from CPU to GPU or TPU.
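JAX already places new arrays on its default device, the first entry of jax.devices(). To move an existing NumPy array explicitly, here is a minimal sketch with jax.device_put; choosing device index 0 is an assumption, and on a CPU-only machine that device is simply the CPU:

```python
import jax
import numpy as np

x_np = np.arange(10)         # lives in host (CPU) memory
device = jax.devices()[0]    # first available device: GPU/TPU if present, else CPU
x_dev = jax.device_put(x_np, device)

print(type(x_dev))           # a JAX array backed by device memory
print(x_dev.devices())       # the set of devices holding the data
```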

2.0.4 A word on Colab TPUs

It used to be easy to switch from GPU to TPU in Colab. However, JAX 0.4 and newer requires TPU VMs (on GCP), which Colab does not provide at the time of writing.

You can still use jax 0.3.25 on a Colab TPU; it is the version installed by default on Colab TPU runtimes. If you’ve already updated JAX, choose Runtime -> Disconnect and Delete Runtime to get a fresh TPU runtime, then skip the pip install step so that you keep the default jax/jaxlib version 0.3.25.

For now, we will proceed with GPUs and run code on Cloud TPU VMs later in the course.

import jax

print(jax.device_count())
device_type = jax.devices()[0].device_kind
device_type
1
'Tesla T4'

We’ll cover how JAX generates random numbers later. For now, let’s initialize a pseudorandom number generator (PRNG) key.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)
x = random.normal(key, (1000, 1000))

print(f"x is of shape: {x.shape}")
print(f"x has dtype: {x.dtype}")
x is of shape: (1000, 1000)
x has dtype: float32
import numpy as np

x = np.array(x)

def x_on_cpu(x):
  # NumPy always computes on the CPU
  return np.dot(x, x)

%timeit -n 1 -r 1 x_on_cpu(x)
29.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
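def x_on_gpu(x):
  # jnp.dot runs on JAX's default device (the T4 GPU here); the NumPy input is copied to the device on each call
  return jnp.dot(x, x)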

%timeit -n 5 -r 5 x_on_gpu(x).block_until_ready()
3.03 ms ± 723 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)

2.0.5 Random numbers

Generating random numbers can seem complicated at first glance.

Pseudorandom number generation (PRNG) creates sequences that aren’t truly random, because they are determined by their initial value, the seed. Each random sample is a deterministic function of some state that is carried over from one sample to the next.

In NumPy, PRNG is based on a global state, using the numpy.random module.

def numpy_random_state():
  print(str(np.random.get_state())[:100], '...')

numpy_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 40 ...

Each call to a random function then updates this state.

np.random.seed(0)

numpy_random_state()

_ = np.random.uniform()

numpy_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 40 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  6 ...

JAX handles PRNG differently, because the framework aims to make random number generation easy to reproduce, parallelize and vectorize.

Rather than a global state, JAX uses an explicit state called a key, which is passed to each random function.

key = random.PRNGKey(10)

print(key)
[ 0 10]

Random functions consume the key but don’t alter it. This means the same key always produces the same sample.

for i in range(0, 3):
  print(random.normal(key))
-1.3445405
-1.3445405
-1.3445405

One practice to bear in mind, then, is never to reuse keys unless identical outputs are necessary. We can obtain independent keys with the split() function.

key, subkey = random.split(key)
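A short sketch of the split-then-use pattern, assuming the same PRNGKey(10) seed as above: reusing a key repeats the sample, while splitting before each draw yields independent samples. random.split(key, n) can also return n keys at once.

```python
from jax import random

key = random.PRNGKey(10)

# Reusing a key gives identical samples
same_a = random.normal(key)
same_b = random.normal(key)

# Split before each use: keep one key for future splits, consume the other
key, subkey = random.split(key)
sample_1 = random.normal(subkey)
key, subkey = random.split(key)
sample_2 = random.normal(subkey)

# split can also produce several keys at once, e.g. one per batch element
keys = random.split(key, 3)
samples = [float(random.normal(k)) for k in keys]
print(sample_1, sample_2, samples)
```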

2.0.6 Intermediate representations

Introducing the Jaxpr

An intermediate representation (IR) is an internal form of a program that frameworks and compilers use to analyze and optimize it. When we write code in a framework such as JAX or PyTorch, the high-level code is converted into a computational graph, or symbolic representation, which is then lowered into an IR. The IR is optimized over several passes that perform operations such as constant folding, operation fusion, model parallelization and quantization. Finally, hardware-specific compilation converts the IR into low-level code optimized for the target backend (TPU, GPU, etc.).

The jaxpr

JAX converts functions into an intermediate representation called a jaxpr. Transformations such as grad then work on the jaxpr representation. JAX builds jaxprs by tracing functions. Before we look at what that means, consider this simple Python function:

def sum_squares(x):
    return jnp.sum(x**2)

What does this function do? At first glance it looks simple: it squares x and sums the result, where x could be a single number or an array. However, given Python’s dynamism, it could do almost anything depending on what x is: square an ice-cream order, a print job or the jackpot winnings of a slot machine.

JAX handles this dynamism by running functions on tracer values. Tracers are abstract stand-ins for the real inputs: they carry shape and dtype information but no concrete values, and they record each operation applied to them, which is how JAX learns what the function will do.
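To see tracers in action, the sketch below (the dictionary seen is a hypothetical helper for this demonstration) records what sum_squares receives when jax.jit traces it. The exact tracer class name is an internal detail that may differ between JAX versions:

```python
import jax
import jax.numpy as jnp

seen = {}


def sum_squares(x):
    # While tracing, x is an abstract tracer: it has a shape and dtype but no concrete values
    seen["type"] = type(x).__name__
    seen["shape"] = x.shape
    seen["dtype"] = str(x.dtype)
    return jnp.sum(x ** 2)


result = jax.jit(sum_squares)(jnp.arange(3.0))
print(seen)    # the recorded type name contains "Tracer"
print(result)  # 0 + 1 + 4 = 5.0
```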

We can look at the IR, the jaxpr, using the jax.make_jaxpr function.

from jax import make_jaxpr

print(make_jaxpr(sum_squares)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    d:f32[] = reduce_sum[axes=()] c
  in (d,) }

The jaxpr consists of primitive operations (the low-level operations exposed in the jax.lax module) that JAX knows how to transform.

Herein lies JAX’s power: without complicating the API, JAX builds a sound picture of what the function does, and so knows how to vectorize it with vmap, parallelize it with pmap and just-in-time compile it with jit.
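Because transformations return ordinary functions, their output can be traced too. A minimal sketch that prints the jaxpr of the gradient of sum_squares; the exact primitives listed may vary between JAX versions:

```python
import jax.numpy as jnp
from jax import grad, make_jaxpr


def sum_squares(x):
    return jnp.sum(x ** 2)


# grad(sum_squares) is a new function, so it can be traced into its own jaxpr
grad_jaxpr = make_jaxpr(grad(sum_squares))(3.0)
print(grad_jaxpr)
```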

2.1 Transformations

Modular, functional programming

JAX can transform functions: a transformation takes a numerical function and returns a new function that, for example, computes the gradient of the original function or parallelizes it. It can even do both!

2.1.1 grad

One of the most commonly used transformations, jax.grad calculates the gradient of a function.

from jax import grad

def sum_squares(x):
    return jnp.sum(x**2)

jax.grad(f) returns a new function that computes the gradient of f, so jax.grad(f)(x) is the gradient of f evaluated at x.

print(grad(sum_squares)(3.0))
6.0
print(grad(grad(sum_squares))(3.0))
2.0

def cylinder_volume(r, h):
    vol = jnp.pi * r**2 * h
    return vol

# Compute the volume of a cylinder with radius 3, and height 3
print(cylinder_volume(3, 3))
84.82300164692441
print(grad(cylinder_volume)(4.0, 8.0))
print(grad(cylinder_volume)(2.0, 6.0))
201.06194
75.398224
print(grad(grad(cylinder_volume))(4.0, 8.0))
print(grad(grad(cylinder_volume))(2.0, 6.0))
50.265484
37.699112
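By default, grad differentiates with respect to the first argument, so the values above are all derivatives with respect to the radius r. The argnums parameter selects other arguments, or several at once. A minimal sketch:

```python
import jax.numpy as jnp
from jax import grad


def cylinder_volume(r, h):
    return jnp.pi * r**2 * h


dV_dr = grad(cylinder_volume)                    # default argnums=0: with respect to r
dV_dh = grad(cylinder_volume, argnums=1)         # with respect to h: pi * r**2
dV_both = grad(cylinder_volume, argnums=(0, 1))  # tuple of both partial derivatives

print(dV_dh(4.0, 8.0))    # pi * 4**2, about 50.265
print(dV_both(4.0, 8.0))  # (2 * pi * r * h, pi * r**2), about (201.06, 50.27)
```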

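grad also works through Python control flow. In this piecewise function, a positive input takes the first branch and a negative input the second, and grad differentiates whichever branch runs: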

def f(x):
  if x > 0:
    return 2 * x ** 3
  else:
    return 3 * x
key = random.PRNGKey(0)
x = random.normal(key, ())
print(key)
print(x)

print(grad(f)(x))
print(grad(f)(-x))
[0 0]
-0.20584226
3.0
0.2542262

A loss function is an obvious place to use grad.

def loss(preds, targets):
  return jnp.sum((preds-targets)**2)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
y = jnp.asarray([3.0, 3.0, 3.0, 3.0])  # targets

print(grad(loss)(x, y))
[-4. -2.  0.  2.]

2.1.2 Value and grad

We can return both the value and gradient of a function using value_and_grad. This is a common pattern in machine learning for logging training loss.

from jax import value_and_grad

value_and_grad(loss)(x, y)
(Array(6., dtype=float32), Array([-4., -2.,  0.,  2.], dtype=float32))
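To show the logging pattern in a training loop, here is a minimal gradient-descent sketch. The linear model, the learning rate of 0.1, the step count and the noiseless data are all hypothetical choices for illustration, not part of the original example:

```python
import jax.numpy as jnp
from jax import jit, value_and_grad


def loss(params, xs, ys):
    # Hypothetical linear model, preds = w * x + b, scored with mean squared error
    w, b = params
    preds = w * xs + b
    return jnp.mean((preds - ys) ** 2)


@jit
def train_step(params, xs, ys, lr=0.1):
    # One gradient-descent step that also returns the loss for logging
    loss_value, grads = value_and_grad(loss)(params, xs, ys)
    new_params = tuple(p - lr * g for p, g in zip(params, grads))
    return new_params, loss_value


# Hypothetical noiseless data generated from w = 3.0, b = 0.5
xs = jnp.linspace(-1.0, 1.0, 50)
ys = 3.0 * xs + 0.5

params = (jnp.array(0.0), jnp.array(0.0))
for step in range(500):
    params, loss_value = train_step(params, xs, ys)
    if step % 100 == 0:
        print(f"step {step}: loss {float(loss_value):.6f}")

w, b = params
print(f"w = {float(w):.3f}, b = {float(b):.3f}")
```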

2.1.3 jit

The jax.jit() transformation performs Just In Time (JIT) compilation of a JAX Python function for efficient execution in XLA.

Let’s go back to our sum_squares() function and time its original implementation on an array of the integers 0 to 99.

from jax import jit

def sum_squares(x):
    return jnp.sum(x**2)

x = jnp.arange(100)

%timeit sum_squares(x).block_until_ready()
419 µs ± 41.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Let’s jit the function and notice the speed improvement.

sum_squares_jit = jit(sum_squares)

# Warm up
sum_squares_jit(x).block_until_ready()

%timeit sum_squares_jit(x).block_until_ready()
72.6 µs ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In case this notation isn’t familiar, µs denotes a microsecond (a millionth of a second) and ns a nanosecond (a billionth of a second). Our jitted function is considerably faster in this simple example.

Note: JAX uses asynchronous dispatch, so a Python call can return before the computation has finished, and the returned array is not necessarily populated yet. Calling block_until_ready() waits for the result, so we time the actual computation rather than just the dispatch.
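The warm-up call before %timeit matters because the first call to a jitted function includes tracing and compilation. A minimal sketch, using time.perf_counter instead of %timeit so it runs outside a notebook; the timings you see will depend on your hardware:

```python
import time

import jax.numpy as jnp
from jax import jit


@jit
def sum_squares(x):
    return jnp.sum(x ** 2)


def timed_call(fn, *args):
    # Wait for the asynchronous result so we measure the computation, not just the dispatch
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


x = jnp.arange(100)
out_first, t_first = timed_call(sum_squares, x)    # includes tracing and XLA compilation
out_second, t_second = timed_call(sum_squares, x)  # reuses the cached compiled function
print(f"first call: {t_first * 1e3:.2f} ms, second call: {t_second * 1e6:.1f} µs")
```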

2.1.4 Sharp edges

It isn’t possible or economical to jit everything. jit raises an error when Python control flow depends on the value of a traced input (e.g. if x < 5: …), and compilation itself adds overhead. jit is best reserved for complex functions that run many times, such as the weight update in a training loop.
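A minimal sketch of this sharp edge, reusing the piecewise function from the grad section, plus two common workarounds: rewriting the branch with jnp.where, or marking the argument as static with static_argnums so that JAX recompiles for each distinct value. Catching jax.errors.ConcretizationTypeError is an assumption about the error type; recent JAX versions raise a subclass of it here:

```python
import jax
import jax.numpy as jnp
from jax import jit


def f(x):
    # Python `if` on the value of x: fine eagerly and under grad, but not under jit
    if x > 0:
        return 2 * x ** 3
    return 3 * x


# Under jit, x is a tracer with no concrete value, so `x > 0` cannot become a Python bool
try:
    jit(f)(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True


# Workaround 1: express the branch with jnp.where, which traces both branches
@jit
def f_where(x):
    return jnp.where(x > 0, 2 * x ** 3, 3 * x)


# Workaround 2: treat x as static; JAX recompiles for every new value of x
f_static = jit(f, static_argnums=0)

print(jit_failed, f_where(2.0), f_where(-1.0), f_static(2.0))
```

jnp.where evaluates both branches and selects the result, so it stays traceable; static_argnums is best kept for arguments that take only a few distinct values, since each new value triggers a recompilation.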

2.1.5 vmap

The jax.vmap transformation generates a vectorized implementation of a function.

Reference for this section (thanks to DeepMind).

We could loop over a batch in Python, but such loops tend to be slow.

from jax import vmap

mat_key, x_key = random.split(key)  # separate keys so the two arrays are independent
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(x_key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
3.62 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
1.17 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@jit
def jit_vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('jitted and auto-vectorized with vmap')
%timeit jit_vmap_batched_apply_matrix(batched_x).block_until_ready()
jitted and auto-vectorized with vmap
86.9 µs ± 10 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
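vmap’s in_axes argument controls which arguments are batched and along which axis; None means the argument is shared across the batch. A minimal sketch, with weighted_sum as a hypothetical example function:

```python
import jax.numpy as jnp
from jax import vmap


def weighted_sum(v, w):
    return jnp.dot(v, w)


vs = jnp.arange(12.0).reshape(4, 3)  # a batch of four 3-vectors
w = jnp.array([1.0, 0.0, -1.0])      # one weight vector shared by the whole batch

# in_axes=(0, None): map over axis 0 of vs, pass w unchanged to every call
out = vmap(weighted_sum, in_axes=(0, None))(vs, w)
print(out)  # [-2. -2. -2. -2.]
```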

Putting it all together

We take a loss function, compute its gradients with grad, vectorize the gradient computation across a batch with vmap, then jit compile the result, all in one line.

import jax.numpy as jnp
from jax import grad, vmap, jit

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def mse_loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    loss = jnp.sum((preds - targets) ** 2)
    return loss

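# Gradient of the loss for the whole batch
gradients = jit(grad(mse_loss))
# Per-example gradients: params are shared (None), the batch is mapped over axis 0
vectorized_gradients = jit(vmap(grad(mse_loss), in_axes=(None, 0)))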
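The snippet above defines the transformed functions but never calls them. A minimal usage sketch, assuming a hypothetical two-layer network (3 inputs, 4 hidden units, 1 output) with randomly initialized weights and a batch of 8 examples. Because the loss is a sum over examples, the per-example gradients should add up to the batch gradient:

```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap


def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs


def mse_loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)


gradients = jit(grad(mse_loss))
vectorized_gradients = jit(vmap(grad(mse_loss), in_axes=(None, 0)))

# Hypothetical sizes: 3 input features -> 4 hidden units -> 1 output
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
params = [
    (random.normal(k1, (3, 4)), jnp.zeros(4)),
    (random.normal(k2, (4, 1)), jnp.zeros(1)),
]
batch = (random.normal(k3, (8, 3)), random.normal(k4, (8, 1)))

batch_grads = gradients(params, batch)             # one gradient for the whole batch
per_example = vectorized_gradients(params, batch)  # one gradient per example
print(batch_grads[0][0].shape)  # (3, 4)
print(per_example[0][0].shape)  # (8, 3, 4)
```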

2.1.6 pmap

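The jax.pmap transformation maps a function over the leading axis of its input and runs each slice on a separate device. That axis cannot be larger than jax.device_count(), so on a single-GPU runtime we cannot pmap over all 10 rows of batched_x directly. Instead, we reshape the batch into one shard per device and use vmap within each device:

def pmap_batched_apply_matrix(v_batched):
  # Reshape (batch, 100) into (n_devices, batch // n_devices, 100); assumes the batch divides evenly
  n_devices = jax.device_count()
  sharded = v_batched.reshape(n_devices, -1, v_batched.shape[-1])
  # Each device applies the vmapped function to its own shard
  return pmap(vmap(apply_matrix))(sharded)

pmap_batched_apply_matrix(batched_x).shape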
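pmap can also communicate between devices through collective operations on a named axis. A minimal sketch using jax.lax.psum to divide every shard by a global total; the axis name "devices" is an arbitrary label, and on a single-device runtime the mapped axis simply has size 1:

```python
import jax
import jax.numpy as jnp
from jax import pmap

n_devices = jax.device_count()
# One row of data per device (a single row on a one-GPU or CPU-only runtime)
xs = jnp.arange(1, n_devices * 4 + 1, dtype=jnp.float32).reshape(n_devices, 4)


def normalize(x):
    # psum adds up the per-device sums across all devices on the named axis
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total


out = pmap(normalize, axis_name="devices")(xs)
print(out.sum())  # close to 1.0: every shard was divided by the same global total
```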