import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax import tree_util, random
3 Linear Regression in JAX
Reference: JAX for the impatient
3.0.1 Pytrees
Before we jump into linear regression in JAX, let’s have a quick look at pytrees. Pytrees are everywhere in JAX and Flax.
JAX treats a pytree as a container of leaf elements. Containers can be lists, tuples, and dicts, so a pytree is essentially a structure for nested data. Container types do not need to match when nested.
JAX provides the tree_util package for working with pytrees.
from jax import tree_util
tree = [1, {"k1": 2, "k2": (3, 4)}, 5]
print('tree:', tree)
tree: [1, {'k1': 2, 'k2': (3, 4)}, 5]
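To see which elements JAX treats as leaves, tree_util also provides tree_leaves and tree_structure; a quick sketch (not part of the original example):
print(tree_util.tree_leaves(tree))      # [1, 2, 3, 4, 5]
print(tree_util.tree_structure(tree))   # a PyTreeDef describing the container structure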
The tree_map function is frequently used for updating a tree and its leaves.
tree_util.tree_map(lambda x: x*2, tree)
[2, {'k1': 4, 'k2': (6, 8)}, 10]
We can also provide additional trees with the same structure as the original tree, which lets the function operate on the corresponding leaves of each tree.
transformed_tree = tree_util.tree_map(lambda x: x*2, tree)
tree_util.tree_map(lambda x, y: x+y, tree, transformed_tree)
[3, {'k1': 6, 'k2': (9, 12)}, 15]
# Linear feed-forward.
def predict(W, b, x):
    return jnp.dot(x, W) + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        y_pred = predict(W, b, x)
        return jnp.inner(y-y_pred, y-y_pred) / 2.0
    # We vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
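For intuition, the vmap over squared_error is equivalent to computing the residuals for the whole batch at once; a minimal sketch (mse_batched is a hypothetical name, not from the original):
# Equivalent batched formulation (sketch only; mse_batched is not part of the original code).
def mse_batched(W, b, x_batched, y_batched):
    err = y_batched - (jnp.dot(x_batched, W) + b)    # residuals, shape (n_samples, y_dim)
    return jnp.mean(jnp.sum(err**2, axis=-1) / 2.0)  # average of 0.5 * ||err||^2 per sample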
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)
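As a hedged sanity check (not part of the original), the closed-form least-squares fit on the noisy samples should land near the ground-truth W and b; a sketch using jnp.linalg.lstsq:
# Closed-form least-squares fit with a bias column appended (sketch only).
X_aug = jnp.hstack([x_samples, jnp.ones((n_samples, 1))])
coef, _, _, _ = jnp.linalg.lstsq(X_aug, y_samples)
W_ls, b_ls = coef[:-1], coef[-1]
print('max |W_ls - W|:', jnp.max(jnp.abs(W_ls - W)))
print('max |b_ls - b|:', jnp.max(jnp.abs(b_ls - b)))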
In this linear regression, params is a pytree which contains W and b.
# Linear feed-forward that takes a params pytree.
def predict_pytree(params, x):
    return jnp.dot(x, params['W']) + params['b']

# Loss function: Mean squared error.
def mse_pytree(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        y_pred = predict_pytree(params, x)
        return jnp.inner(y-y_pred, y-y_pred) / 2.0
    # We vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# Initialize estimated W and b with zeros. Store in a pytree.
params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}
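Since params has the same structure as {'W': W, 'b': b}, the same zero initialization can be written with tree_map; a one-line sketch (not from the original):
# Equivalent initialization via tree_map (sketch only).
params = tree_util.tree_map(jnp.zeros_like, {'W': W, 'b': b})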
JAX can differentiate with respect to the pytree parameters:
jax.grad(mse_pytree)(params, x_samples, y_samples)
{'W': Array([[ 3.02512199e-05, 2.38317996e-04, 5.86672686e-05,
1.45167112e-04, -1.08840875e-04],
[-7.21593387e-05, -5.92094846e-04, -1.44919526e-04,
-3.55741940e-04, 2.12500338e-04],
[ 6.48805872e-06, 5.42663038e-05, 1.30501576e-05,
3.26364461e-05, -1.73687004e-05],
[-1.58965122e-05, -1.67803839e-04, -3.67928296e-05,
-9.91356210e-05, 1.37500465e-05],
[-9.83832870e-05, -7.85302604e-04, -1.94110908e-04,
-4.74306522e-04, 3.20815481e-04],
[-6.23832457e-05, -5.39575703e-04, -1.28836837e-04,
-3.22562526e-04, 1.54131034e-04],
[-8.21873546e-05, -6.98693097e-04, -1.67359249e-04,
-4.20103315e-04, 2.30040867e-04],
[-7.44033605e-06, -5.03805932e-05, -1.21158082e-05,
-3.23755667e-05, 4.07989137e-05],
[ 2.81375833e-06, 5.13494015e-05, 9.00402665e-06,
2.99257226e-05, 1.63719524e-05],
[-2.90344469e-05, -2.14974396e-04, -5.49759716e-05,
-1.30489469e-04, 1.08879060e-04]], dtype=float32),
'b': Array([-3.0185096e-05, -2.2265501e-04, -5.7548052e-05, -1.3502731e-04,
1.1305924e-04], dtype=float32)}
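The gradient comes back as a pytree with the same structure as params; a quick check of the leaf shapes (a sketch, not part of the original output):
grads = jax.grad(mse_pytree)(params, x_samples, y_samples)
print(tree_util.tree_map(jnp.shape, grads))  # expect {'W': (10, 5), 'b': (5,)}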
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(mse_pytree)(params, x_samples, y_samples))
    return params
learning_rate = 0.3  # Gradient step size.

print('Loss for "true" W,b: ', mse_pytree({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
    # Perform one gradient update.
    params = update_params_pytree(params, learning_rate, x_samples, y_samples)
    if (i % 5 == 0):
        print(f"Loss step {i}: ", mse_pytree(params, x_samples, y_samples))
Loss for "true" W,b: 0.02363979
Loss step 0: 10.97141
Loss step 5: 1.0798324
Loss step 10: 0.3795825
Loss step 15: 0.17855297
Loss step 20: 0.094415195
Loss step 25: 0.054522194
Loss step 30: 0.03448924
Loss step 35: 0.024058029
Loss step 40: 0.018480862
Loss step 45: 0.015438682
Loss step 50: 0.01375394
Loss step 55: 0.0128103
Loss step 60: 0.012277315
Loss step 65: 0.011974388
Loss step 70: 0.011801446
Loss step 75: 0.011702419
Loss step 80: 0.011645543
Loss step 85: 0.011612838
Loss step 90: 0.011594015
Loss step 95: 0.011583163
Loss step 100: 0.011576912
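As a sanity check (not part of the original output), the learned parameters can be compared against the ground truth used to generate the data:
print('max |W - learned W|:', jnp.max(jnp.abs(W - params['W'])))
print('max |b - learned b|:', jnp.max(jnp.abs(b - params['b'])))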
Here we can also use jax.value_and_grad() to compute both the return value of the input function and its gradient.
# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(mse_pytree)
for i in range(101):
    # Note that here the loss is computed before the param update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    if (i % 5 == 0):
        print(f"Loss step {i}: ", loss_val)
Loss step 0: 0.011576912
Loss step 5: 0.011573299
Loss step 10: 0.011571216
Loss step 15: 0.011570027
Loss step 20: 0.0115693165
Loss step 25: 0.011568918
Loss step 30: 0.011568695
Loss step 35: 0.01156855
Loss step 40: 0.011568478
Loss step 45: 0.011568436
Loss step 50: 0.011568408
Loss step 55: 0.011568391
Loss step 60: 0.01156838
Loss step 65: 0.011568381
Loss step 70: 0.01156838
Loss step 75: 0.011568385
Loss step 80: 0.011568374
Loss step 85: 0.01156838
Loss step 90: 0.01156837
Loss step 95: 0.0115683675
Loss step 100: 0.01156837
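The loss/gradient computation and the parameter update can also be combined into a single jitted step; a sketch assuming learning_rate is still in scope (update_step is our name, not from the original):
@jax.jit
def update_step(params, x, y):
    # Compute loss and gradients together, then apply one gradient-descent step.
    loss_val, grads = jax.value_and_grad(mse_pytree)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss_val

params, loss_val = update_step(params, x_samples, y_samples)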