5 Flax Foundations
Efficient and flexible model development
By combining JAX’s auto-differentiation and Flax’s modular design, developers can easily construct and train state-of-the-art deep learning models. JAX/Flax traces pure functions and compiles them for GPU and TPU accelerators.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
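As a quick, minimal sketch (not part of the original example, using the jax and jnp imports above) of what this tracing-and-compilation step looks like in plain JAX: any pure function can be handed to jax.jit, which traces it once and compiles the result with XLA for whichever backend is available.
# A pure function: its output depends only on its inputs.
def affine(w, b, x):
    return x @ w + b

# jax.jit traces `affine` once, then compiles the traced computation
# with XLA for the available backend (CPU here, GPU/TPU if present).
affine_jit = jax.jit(affine)

w, b = jnp.ones((3, 2)), jnp.zeros(2)
affine_jit(w, b, jnp.ones((4, 3)))  # first call compiles; later calls reuse the compiled code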
# Here's a single dense layer that takes a number of features as input
model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
# Dummy input data
x = random.normal(key1, (10,))
# Initialize the model
params = model.init(key2, x)
# Forward pass
model.apply(params, x)
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Array([-1.3721193 , 0.61131495, 0.6442836 , 2.2192965 , -1.1271116 ], dtype=float32)
Note that we only tell Flax the number of output features of the model, rather than specifying the size of the input. Flax works out the correct kernel shape for us at initialization!
Let’s take a look at the pytree:
# Check output shapes
jax.tree_util.tree_map(lambda x: x.shape, params)
FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})
Notice the parameters are stored in a FrozenDict, which prevents any mutation of the values.
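If you do need to change a parameter by hand (say, to zero out the bias), the usual pattern is to unfreeze the pytree into a mutable copy, edit it, and freeze it again. A minimal sketch using the freeze/unfreeze helpers imported earlier:
# Make an ordinary (mutable) dict copy of the parameter pytree.
mutable_params = unfreeze(params)
# Edit a leaf: zero out the bias of the Dense layer.
mutable_params['params']['bias'] = jnp.zeros(5)
# Freeze again before handing it back to `model.apply`.
new_params = freeze(mutable_params)
model.apply(new_params, x)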
import jax
import jax.numpy as jnp
import flax.linen as nn
# Dummy data
inputs = jnp.array([[0.2, 0.3, 0.4], [0.1, 0.2, 0.3]])
targets = jnp.array([[0.5], [0.8]])
# Simple feedforward neural network
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x
# Initialization
hidden_size = 16
output_size = 1
rng = jax.random.PRNGKey(0)
model = SimpleNetwork(hidden_size, output_size)
params = model.init(rng, inputs)
tree = jax.tree_util.tree_map(lambda p: p.shape, params)  # Check parameter shapes
print(tree)
# Forward pass
predictions = model.apply(params, inputs)
print(f"Inputs: \n{inputs}")
print(f"\nPredictions: \n{predictions}")
print(f"\nTarget data: \n{targets}")FrozenDict({
params: {
dense1: {
bias: (16,),
kernel: (3, 16),
},
dense2: {
bias: (1,),
kernel: (16, 1),
},
},
})
Inputs:
[[0.2 0.3 0.4]
[0.1 0.2 0.3]]
Predictions:
[[-0.01026188]
[-0.01458298]]
Target data:
[[0.5]
[0.8]]
In this example, we defined our model explicitly using setup. We can also define architectures using nn.compact, which allows us to define a module as a single method. This can lead to cleaner code if you are writing custom layers.
Here’s our SimpleNetwork again, using setup.
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x
And using nn.compact:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
If you are porting models from PyTorch, or prefer explicit definition and separation of submodules, setup may suit you better. nn.compact may be best for reducing duplication, writing code that reads closer to mathematical notation, or when you rely on shape inference (parameters that depend on input shapes not known until the module is first called).
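To illustrate the shape-inference point, here is a hedged sketch (the ResidualBlock module is ours, not from the text) of an nn.compact module whose layer width is read off the input's last dimension the first time it is called:
class ResidualBlock(nn.Module):
    """Illustrative block: the Dense width is inferred from the input."""

    @nn.compact
    def __call__(self, x):
        # x.shape[-1] is only known at trace time, so the kernel shape
        # is worked out during `init` rather than at module definition.
        h = nn.Dense(x.shape[-1])(x)
        h = nn.relu(h)
        return x + h

block = ResidualBlock()
variables = block.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))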
5.0.1 Flax modules
Flax makes it easy to incorporate training techniques such as batch normalization and learning rate scheduling via the flax.linen.Module API.
Here’s our simple multi-layer perceptron again:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
Batch normalization is a regularization technique which normalizes intermediate activations, keeping running averages of per-feature statistics. This speeds up training and improves convergence. To apply batch normalization, we use flax.linen.BatchNorm.
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
5.0.2 Dropout
Dropout is another (stochastic) regularization technique that randomly removes units in a network to reduce overfitting and improve generalization.
Because dropout is a random operation, it requires our PRNG skills.
When splitting a key, we can simply split into three keys, reserving the third for flax.linen.Dropout.
key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=key, num=3)
Then add the module to our model:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
We can then initialize the model:
simple_net = SimpleNetwork(hidden_size=5, output_size=1)
x = jnp.empty((3, 4, 4, 5, 5))
# `train=False` sets `deterministic=True`, so dropout is disabled during initialization.
variables = simple_net.init(params_key, x, train=False)
params = variables['params']
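At training time the call looks a little different: dropout needs its own PRNG stream, and BatchNorm's running statistics live in a mutable batch_stats collection that has to be threaded through apply. A minimal sketch of a training-mode forward pass, reusing the keys and variables created above:
# Training-mode forward pass: supply a dropout RNG and mark the
# `batch_stats` collection as mutable so BatchNorm can update it.
batch_stats = variables['batch_stats']
output, updates = simple_net.apply(
    {'params': params, 'batch_stats': batch_stats},
    x,
    train=True,
    rngs={'dropout': dropout_key},
    mutable=['batch_stats'],
)
# The updated running statistics are returned separately.
batch_stats = updates['batch_stats']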
5.0.3 Train states
A “train state” is the mutable state of a model during training, including properties such as its parameters (weights) and the optimizer state.
The train state is typically represented as an instance of the flax.training.train_state.TrainState class, which encapsulates this state and provides methods to update it.
One of the defining features of JAX/Flax is the functional-programming emphasis on immutability. Model updates are purely functional: each step returns a new state rather than mutating the old one, which enables model parallelism and efficient training.
# Example, will not run
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
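To see why this is convenient, here is a sketch of a single, purely functional training step built on such a state (the batch arguments and squared-error loss are illustrative, not from the original): the step takes a TrainState and returns a brand-new one rather than mutating anything in place.
@jax.jit
def train_step(state, batch_x, batch_y):
    """One functional optimization step: returns a new TrainState."""
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch_x)
        return jnp.mean((preds - batch_y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # `apply_gradients` runs the Optax update and returns a new state;
    # the old `state` is left untouched.
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss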
5.0.4 Optax
Optax is a gradient processing and optimization package. It is generally used with Flax as follows:
1. Create an optimizer state from the parameters using any optimization method (e.g. optax.rmsprop).
2. Compute loss gradients using value_and_grad().
3. Call the Optax update function, which updates the internal optimizer state and works out how to tweak the parameters.
4. Use apply_updates to apply the updates to the parameters.
For example (will not run):
import optax
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
loss_grad_func = jax.value_and_grad(mse)
for i in range(10):
    loss, grads = loss_grad_func(params, x_samples, y_samples)
    updates, optimizer_state = optimizer.update(grads, optimizer_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss)
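Optax also provides the learning-rate scheduling mentioned earlier: a schedule is simply a function from step count to learning rate, and it can be passed to most optimizers in place of a constant. A minimal sketch (the schedule values are illustrative):
import optax

# Decay the learning rate by a factor of 0.9 every 1000 update steps.
schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=1_000,
    decay_rate=0.9,
)
optimizer = optax.adam(learning_rate=schedule)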
MNIST Example
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds['image'][perm, ...]
        batch_labels = train_ds['label'][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
def train_and_evaluate(learning_rate, momentum,
                       batch_size, num_epochs) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Args:
        learning_rate: Learning rate for the SGD optimizer.
        momentum: Momentum for the SGD optimizer.
        batch_size: Number of examples per training batch.
        num_epochs: Number of passes over the training set.

    Returns:
        The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, learning_rate, momentum)

    for epoch in range(1, num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                        batch_size,
                                                        input_rng)
        _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                                  test_ds['label'])

        print(
            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
            % (epoch, train_loss, train_accuracy * 100, test_loss,
               test_accuracy * 100))

        print('train_loss', train_loss, epoch)
        print('train_accuracy', train_accuracy, epoch)
        print('test_loss', test_loss, epoch)
        print('test_accuracy', test_accuracy, epoch)

    return state
train_and_evaluate(0.01, 0.9, 128, 1)