import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
5 Flax Foundations
Efficient and flexible model development
By combining JAX’s automatic differentiation with Flax’s modular design, developers can easily construct and train state-of-the-art deep learning models. JAX traces pure functions and compiles them for GPU and TPU accelerators; Flax supplies the neural-network building blocks on top.
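As a quick illustration of that tracing-and-compilation step (a minimal sketch, not part of the original example):
@jax.jit
def scaled_sum(x):
    # Traced once on the first call; the compiled version is reused afterwards.
    return jnp.sum(x * 2.0)

print(scaled_sum(jnp.arange(3.0)))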
# Here's a single dense layer that takes a number of features as input
model = nn.Dense(features=5)

key1, key2 = random.split(random.PRNGKey(0))
# Dummy input data
x = random.normal(key1, (10,))
# Initialize the model
params = model.init(key2, x)
# Forward pass
model.apply(params, x)
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Array([-1.3721193 , 0.61131495, 0.6442836 , 2.2192965 , -1.1271116 ], dtype=float32)
Note that we only give Flax the number of output features of the model, rather than specifying the size of the input. Flax works out the correct kernel shape for us!
Let’s take a look at the pytree:
# Check output shapes
jax.tree_util.tree_map(lambda x: x.shape, params)
FrozenDict({
params: {
bias: (5,),
kernel: (10, 5),
},
})
Notice the parameters are stored in a FrozenDict, which prevents any mutation of the values.
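If you do need to modify parameters, the usual pattern is to unfreeze the dict, edit it, and freeze it again. A minimal sketch using the freeze/unfreeze helpers imported above (the zeroed bias is just for illustration):
from flax.core import freeze, unfreeze

# Unfreeze to a plain dict, modify, then freeze back into a FrozenDict.
mutable_params = unfreeze(params)
mutable_params['params']['bias'] = jnp.zeros_like(mutable_params['params']['bias'])
params = freeze(mutable_params)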
import jax
import jax.numpy as jnp
import flax.linen as nn
# Dummy data
inputs = jnp.array([[0.2, 0.3, 0.4], [0.1, 0.2, 0.3]])
targets = jnp.array([[0.5], [0.8]])

# Simple feedforward neural network
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

# Initialization
hidden_size = 16
output_size = 1
rng = jax.random.PRNGKey(0)
model = SimpleNetwork(hidden_size, output_size)
params = model.init(rng, inputs)

# Checking output shapes
tree = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(tree)

# Forward pass
predictions = model.apply(params, inputs)

print(f"Inputs: \n{inputs}")
print(f"\nPredictions: \n{predictions}")
print(f"\nTarget data: \n{targets}")
FrozenDict({
params: {
dense1: {
bias: (16,),
kernel: (3, 16),
},
dense2: {
bias: (1,),
kernel: (16, 1),
},
},
})
Inputs:
[[0.2 0.3 0.4]
[0.1 0.2 0.3]]
Predictions:
[[-0.01026188]
[-0.01458298]]
Target data:
[[0.5]
[0.8]]
In this example, we defined our model explicitly using setup. We can also define architectures using nn.compact, which allows us to define a module as a single method. This can lead to cleaner code if you are writing custom layers.
Here's our SimpleNetwork again, using setup.
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x
And using nn.compact:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
If you are porting models from PyTorch, or prefer explicit definition and separation of submodules, setup may suit you better. nn.compact may be best for reducing duplication, writing code that looks closer to mathematical notation, or if you are using shape inference (parameters that depend on input shapes unknown until initialization).
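To sketch that last point, here is a hypothetical ResidualBlock (not from the original text) whose layer widths are derived from the input's last dimension; nn.compact can express this directly because the shape only becomes known when the module is first called:
# Hypothetical module for illustration: layer widths depend on the input shape.
class ResidualBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(x.shape[-1] * 2)(x)   # width inferred from the input
        h = nn.relu(h)
        h = nn.Dense(x.shape[-1])(h)       # project back to the input width
        return x + h

block = ResidualBlock()
params = block.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))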
5.0.1 Flax modules
Flax makes it easy to incorporate training techniques such as batch normalization and dropout via built-in flax.linen modules (learning rate scheduling is handled by Optax, covered below).
Here’s our simple multi-layer perceptron again:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
Batch normalization is a technique that normalizes activations using batch statistics and maintains running averages over the feature dimensions for use at evaluation time. This stabilizes training and improves convergence. To apply batch normalization, we call upon flax.linen.BatchNorm.
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
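Because BatchNorm stores its running statistics in a separate batch_stats collection, the training-mode forward pass has to mark that collection as mutable. A minimal sketch (names and shapes chosen arbitrarily for illustration):
bn_net = SimpleNetwork(hidden_size=8, output_size=1)
x = jnp.ones((4, 3))
variables = bn_net.init(jax.random.PRNGKey(0), x, train=False)

# Training step: returns the output plus the updated running statistics.
y, updates = bn_net.apply(variables, x, train=True, mutable=['batch_stats'])
variables = {'params': variables['params'], 'batch_stats': updates['batch_stats']}

# Evaluation: use the stored running averages; nothing is mutated.
y_eval = bn_net.apply(variables, x, train=False)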
5.0.2 Dropout
Dropout is another (stochastic) regularization technique that randomly removes units in a network to reduce overfitting and improve generalization.
Dropout calls on our PRNG skills, because it is a random operation. When splitting a key, we can simply split into three keys, reserving the third for flax.linen.Dropout.
key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=key, num=3)
Then add the module to our model:
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(self.hidden_size, name="dense1")(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size, name="dense2")(x)
        return x
We can then initialize the model:
simple_net = SimpleNetwork(hidden_size=5, output_size=1)
x = jnp.empty((3, 4, 4, 5, 5))
# Dropout is disabled here: `train=False` sets `deterministic=True`.
variables = simple_net.init(params_key, x, train=False)
params = variables['params']
5.0.3 Train states
A “train state” is the mutable state of a model during training, including properties such as its parameters (weights) and optimizer state.
The train state is typically represented as an instance of the flax.training.train_state.TrainState class, which encapsulates the state and provides methods to update it.
One of the defining features of JAX/Flax is its functional-programming characteristic of immutability. Model updates are purely functional, enabling model parallelism and efficient training.
# Example, will not run
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
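To make the functional-update point concrete, here is a sketch of a train step (not from the original text) under the same assumptions as above (a CNN classifier, with optax and flax.training.train_state imported); apply_gradients returns a brand-new TrainState rather than mutating the old one:
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        one_hot = jax.nn.one_hot(batch['label'], 10)
        return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Returns a new TrainState; the input `state` is left untouched.
    return state.apply_gradients(grads=grads), loss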
5.0.4 Optax
Optax is a gradient processing and optimization package. It is generally used with Flax as follows:
1. Create an optimizer state from the parameters using any optimization method (e.g. optax.rmsprop).
2. Compute loss gradients using value_and_grad().
3. Call the Optax update function, which updates the internal optimizer state and works out how to tweak the parameters.
4. Use apply_updates to apply the updates to the parameters.
For example (will not run):
import optax

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
loss_grad_func = jax.value_and_grad(mse)

for i in range(10):
    loss, grads = loss_grad_func(params, x_samples, y_samples)
    updates, optimizer_state = optimizer.update(grads, optimizer_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss)
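For a version that does run, here is a self-contained sketch with synthetic data and a hypothetical mse loss, fitting y = 2x + 1 with a single Dense layer:
# Synthetic data for a small regression problem.
key = jax.random.PRNGKey(0)
x_samples = jax.random.normal(key, (32, 1))
y_samples = 2.0 * x_samples + 1.0

linear = nn.Dense(features=1)
params = linear.init(key, x_samples)

def mse(params, x, y):
    preds = linear.apply(params, x)
    return jnp.mean((preds - y) ** 2)

optimizer = optax.adam(learning_rate=0.1)
optimizer_state = optimizer.init(params)
loss_grad_func = jax.value_and_grad(mse)

for i in range(100):
    loss, grads = loss_grad_func(params, x_samples, y_samples)
    updates, optimizer_state = optimizer.update(grads, optimizer_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f"Loss step {i}: {loss}")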
5.0.5 MNIST Example
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds['image'][perm, ...]
        batch_labels = train_ds['label'][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
def train_and_evaluate(learning_rate, momentum,
                       batch_size, num_epochs) -> train_state.TrainState:
    """Execute the model training and evaluation loop.

    Args:
      learning_rate: SGD learning rate.
      momentum: SGD momentum.
      batch_size: Batch size used for training.
      num_epochs: Number of training epochs.

    Returns:
      The final train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, learning_rate, momentum)

    for epoch in range(1, num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                        batch_size, input_rng)
        _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                                  test_ds['label'])

        print(
            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
            % (epoch, train_loss, train_accuracy * 100, test_loss,
               test_accuracy * 100))

        print('train_loss', train_loss, epoch)
        print('train_accuracy', train_accuracy, epoch)
        print('test_loss', test_loss, epoch)
        print('test_accuracy', test_accuracy, epoch)

    return state


train_and_evaluate(0.01, 0.9, 128, 1)