3  Linear Regression in JAX

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax import tree_util, random

Open In Colab

Reference: JAX for the impatient

3.0.1 Pytrees

Before we jump into linear regression in JAX, let’s have a quick look at pytrees. Pytrees are everywhere in JAX and Flax.

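JAX treats a pytree as a tree of containers whose leaves are the actual values (e.g. numbers or arrays). The containers can be lists, tuples and dicts, so a pytree is essentially a nested data structure. Nested containers do not have to be of the same type: for example, a list can hold a dict that in turn holds a tuple.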

JAX provides the tree_util package for working with pytrees.

from jax import tree_util

# A pytree mixing container types: a list holding ints, a dict, and
# (inside the dict) a tuple.
tree = [1, dict(k1=2, k2=(3, 4)), 5]
print('tree:', tree)
tree: [1, {'k1': 2, 'k2': (3, 4)}, 5]
tree_util.tree_map(lambda x: x*2, tree): [2, {'k1': 4, 'k2': (6, 8)}, 10]

The tree_map function applies a function to every leaf and returns a new tree with the same structure. It is frequently used to update all the leaves of a tree at once.

tree_util.tree_map(lambda leaf: leaf * 2, tree)  # double every leaf, keep the structure
[2, {'k1': 4, 'k2': (6, 8)}, 10]

We can also pass additional trees with the same structure as the original tree. tree_map then calls the function with the corresponding leaf from each tree.

transformed_tree = tree_util.tree_map(lambda leaf: leaf * 2, tree)
# Every extra tree handed to tree_map must share `tree`'s structure; the
# function is called with the corresponding leaf from each tree.
tree_util.tree_map(lambda orig, doubled: orig + doubled, tree, transformed_tree)
[3, {'k1': 6, 'k2': (9, 12)}, 15]
# Linear feed-forward.
def predict(W, b, x):
    """Apply the affine model ``x . W + b``.

    In this file it is called both on a single sample (inside ``mse``) and
    on a whole batch of samples (when generating ``y_samples``).
    """
    projected = jnp.dot(x, W)
    return projected + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
    """Return the mean over samples of ``||y - predict(W, b, x)||^2 / 2``.

    Note the factor 1/2 on each per-sample squared error.
    """
    def half_squared_norm(x, y):
        # Per-sample loss; vmapped over the leading (sample) axis below.
        residual = y - predict(W, b, x)
        return jnp.inner(residual, residual) / 2.0

    per_sample_losses = jax.vmap(half_squared_norm)(x_batched, y_batched)
    return jnp.mean(per_sample_losses, axis=0)
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
# All keys come from a single split so that no PRNG key is consumed more
# than once; JAX keys must never be reused (a reused key yields correlated
# "random" numbers). Changing the split scheme changes the sampled data, so
# printed loss values depend on it.
key = random.PRNGKey(0)
k1, k2, key_sample, key_noise = random.split(key, 4)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

In this linear regression, params is a pytree which contains W and b.

# Linear feed-forward that takes a params pytree.
def predict_pytree(params, x):
    """Affine model ``x . params['W'] + params['b']`` with params held in a dict."""
    weights, bias = params['W'], params['b']
    return jnp.dot(x, weights) + bias

# Loss function: Mean squared error.
def mse_pytree(params, x_batched, y_batched):
    """Pytree version of the loss: mean over samples of ``||y - y_pred||^2 / 2``.

    ``params`` is a dict with keys ``'W'`` and ``'b'``. Because the loss is a
    function of that whole pytree, ``jax.grad`` returns gradients with the
    same dict structure.
    """
    def half_squared_norm(x, y):
        residual = y - predict_pytree(params, x)
        return jnp.inner(residual, residual) / 2.0

    # vmap maps the per-sample loss over the leading axis of both batches.
    per_sample_losses = jax.vmap(half_squared_norm)(x_batched, y_batched)
    return jnp.mean(per_sample_losses, axis=0)

# Start the estimates for W and b at zero (same shapes/dtypes as the ground
# truth) and keep them together in one dict pytree.
params = dict(W=jnp.zeros_like(W), b=jnp.zeros_like(b))

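JAX can differentiate with respect to the whole pytree of parameters: jax.grad returns a gradient pytree with the same structure as params (here a dict with keys 'W' and 'b').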

jax.grad(mse_pytree, argnums=0)(params, x_samples, y_samples)  # gradient w.r.t. the params pytree
{'W': Array([[ 3.02512199e-05,  2.38317996e-04,  5.86672686e-05,
          1.45167112e-04, -1.08840875e-04],
        [-7.21593387e-05, -5.92094846e-04, -1.44919526e-04,
         -3.55741940e-04,  2.12500338e-04],
        [ 6.48805872e-06,  5.42663038e-05,  1.30501576e-05,
          3.26364461e-05, -1.73687004e-05],
        [-1.58965122e-05, -1.67803839e-04, -3.67928296e-05,
         -9.91356210e-05,  1.37500465e-05],
        [-9.83832870e-05, -7.85302604e-04, -1.94110908e-04,
         -4.74306522e-04,  3.20815481e-04],
        [-6.23832457e-05, -5.39575703e-04, -1.28836837e-04,
         -3.22562526e-04,  1.54131034e-04],
        [-8.21873546e-05, -6.98693097e-04, -1.67359249e-04,
         -4.20103315e-04,  2.30040867e-04],
        [-7.44033605e-06, -5.03805932e-05, -1.21158082e-05,
         -3.23755667e-05,  4.07989137e-05],
        [ 2.81375833e-06,  5.13494015e-05,  9.00402665e-06,
          2.99257226e-05,  1.63719524e-05],
        [-2.90344469e-05, -2.14974396e-04, -5.49759716e-05,
         -1.30489469e-04,  1.08879060e-04]], dtype=float32),
 'b': Array([-3.0185096e-05, -2.2265501e-04, -5.7548052e-05, -1.3502731e-04,
         1.1305924e-04], dtype=float32)}
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
    """Take one full-batch gradient-descent step on ``mse_pytree``.

    Returns a new params pytree in which every leaf has been moved by
    ``-learning_rate * gradient``. The whole step is compiled by ``jax.jit``.
    """
    grads = jax.grad(mse_pytree)(params, x_samples, y_samples)

    def descend(param, grad):
        return param - learning_rate * grad

    # params and grads share the same structure, so tree_map pairs leaves.
    return jax.tree_util.tree_map(descend, params, grads)

learning_rate = 0.3  # Gradient step size.
true_params = {'W': W, 'b': b}
print('Loss for "true" W,b: ', mse_pytree(true_params, x_samples, y_samples))
for i in range(101):
    # One jitted full-batch gradient-descent step.
    params = update_params_pytree(params, learning_rate, x_samples, y_samples)
    if not i % 5:
        # Log every 5th step (0, 5, ..., 100); loss is measured after the update.
        print(f"Loss step {i}: ", mse_pytree(params, x_samples, y_samples))
Loss for "true" W,b:  0.02363979
Loss step 0:  10.97141
Loss step 5:  1.0798324
Loss step 10:  0.3795825
Loss step 15:  0.17855297
Loss step 20:  0.094415195
Loss step 25:  0.054522194
Loss step 30:  0.03448924
Loss step 35:  0.024058029
Loss step 40:  0.018480862
Loss step 45:  0.015438682
Loss step 50:  0.01375394
Loss step 55:  0.0128103
Loss step 60:  0.012277315
Loss step 65:  0.011974388
Loss step 70:  0.011801446
Loss step 75:  0.011702419
Loss step 80:  0.011645543
Loss step 85:  0.011612838
Loss step 90:  0.011594015
Loss step 95:  0.011583163
Loss step 100:  0.011576912

We can also use jax.value_and_grad() to compute both the value of the function and its gradient in a single call.

# Using jax.value_and_grad instead: a single call returns both the loss and
# the gradient pytree.
loss_grad_fn = jax.value_and_grad(mse_pytree)


def _descend(param, grad):
    """Plain gradient-descent update for a single leaf."""
    return param - learning_rate * grad


for i in range(101):
    # The loss reported here is evaluated at the parameters *before* this
    # iteration's update is applied.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = jax.tree_util.tree_map(_descend, params, grads)
    if not i % 5:
        print(f"Loss step {i}: ", loss_val)
Loss step 0:  0.011576912
Loss step 5:  0.011573299
Loss step 10:  0.011571216
Loss step 15:  0.011570027
Loss step 20:  0.0115693165
Loss step 25:  0.011568918
Loss step 30:  0.011568695
Loss step 35:  0.01156855
Loss step 40:  0.011568478
Loss step 45:  0.011568436
Loss step 50:  0.011568408
Loss step 55:  0.011568391
Loss step 60:  0.01156838
Loss step 65:  0.011568381
Loss step 70:  0.01156838
Loss step 75:  0.011568385
Loss step 80:  0.011568374
Loss step 85:  0.01156838
Loss step 90:  0.01156837
Loss step 95:  0.0115683675
Loss step 100:  0.01156837