5 Flax Foundations


Efficient and flexible model development

By combining JAX’s automatic differentiation with Flax’s modular design, developers can construct and train state-of-the-art deep learning models. JAX traces pure functions and compiles them with XLA for CPU, GPU, and TPU accelerators.
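Here is a minimal sketch of tracing and compilation, using an illustrative function and illustrative data. `jax.grad` turns a pure function into its gradient function, and `jax.jit` traces that function once and compiles it for whatever backend is available.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # A pure function: the output depends only on the arguments, with no side effects.
    return jnp.sum((x @ w) ** 2)

# Build the gradient function, then trace and compile it with XLA.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # identity, so loss = w1**2 + w2**2
print(grad_loss(w, x))  # [2. 4.]
```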

import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
# A single dense layer; `features` sets the size of the output, not the input
model = nn.Dense(features=5)

key1, key2 = random.split(random.PRNGKey(0))
# Dummy input data
x = random.normal(key1, (10,))
# Initialize the model
params = model.init(key2, x)
# Forward pass
model.apply(params, x)
Array([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 , -1.1271116 ],      dtype=float32)

Note that we only tell Flax the number of output features, rather than specifying the size of the input. Flax infers the correct kernel shape from the input passed to init!

Let’s take a look at the pytree:

# Check output shapes
jax.tree_util.tree_map(lambda x: x.shape, params)
FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

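Notice that the parameters are stored in a FrozenDict, which prevents in-place mutation of the values. (Newer Flax releases return a regular Python dict from init by default, but parameters are still meant to be treated as immutable.)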

import jax
import jax.numpy as jnp
import flax.linen as nn

# Dummy data
inputs = jnp.array([[0.2, 0.3, 0.4], [0.1, 0.2, 0.3]])
targets = jnp.array([[0.5], [0.8]])

# Simple feedforward neural network
class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

# Initialization
hidden_size = 16
output_size = 1
rng = jax.random.PRNGKey(0)
model = SimpleNetwork(hidden_size, output_size)
params = model.init(rng, inputs)
tree = jax.tree_util.tree_map(lambda p: p.shape, params)  # Check parameter shapes
print(tree)

# Forward pass
predictions = model.apply(params, inputs)

print(f"Inputs: \n{inputs}")
print(f"\nPredictions: \n{predictions}")
print(f"\nTarget data: \n{targets}")
FrozenDict({
    params: {
        dense1: {
            bias: (16,),
            kernel: (3, 16),
        },
        dense2: {
            bias: (1,),
            kernel: (16, 1),
        },
    },
})
Inputs: 
[[0.2 0.3 0.4]
 [0.1 0.2 0.3]]

Predictions: 
[[-0.01026188]
 [-0.01458298]]

Target data: 
[[0.5]
 [0.8]]

In this example, we defined our model explicitly using setup. We can also define architectures using nn.compact, which lets us declare a module's submodules inline in a single method. This can lead to cleaner code if you are writing custom layers.

Here’s our SimpleNetwork again, using setup.

class SimpleNetwork(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(self.output_size)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

And using nn.compact:

class SimpleNetwork(nn.Module):
  hidden_size: int
  output_size: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.hidden_size, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(self.output_size, name="dense2")(x)
    return x

If you are porting models from PyTorch, or prefer explicit definition and separation of submodules, setup may suit you better. nn.compact may be best for reducing duplication, writing code that looks closer to mathematical notation, or using shape inference (parameters whose shapes depend on input shapes that are not known until the module is first called).

5.0.1 Flax modules

Flax makes it easy to incorporate training techniques such as batch normalization and dropout into a flax.linen.Module. Learning rate scheduling, by contrast, is typically handled by the optimizer (see Optax below).

Here’s our simple multi-layer perceptron again:

class SimpleNetwork(nn.Module):
  hidden_size: int
  output_size: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.hidden_size, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(self.output_size, name="dense2")(x)
    return x

Batch normalization normalizes each feature using the mean and variance of the current batch. It also keeps running averages of these statistics for use at evaluation time. It often speeds up training, improves convergence, and can have a mild regularizing effect. To apply batch normalization, we use flax.linen.BatchNorm.

class SimpleNetwork(nn.Module):
  hidden_size: int
  output_size: int

  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(self.hidden_size, name="dense1")(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(self.output_size, name="dense2")(x)
    return x

5.0.2 Dropout

Dropout is another (stochastic) regularization technique that randomly zeroes units in a network during training to reduce overfitting and improve generalization.

Because dropout is a random operation, it needs its own PRNG key.

When splitting the root key, we can split it into three keys and reserve the third for flax.linen.Dropout.

key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=key, num=3)

Then add the module to our model:

class SimpleNetwork(nn.Module):
  hidden_size: int
  output_size: int

  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(self.hidden_size, name="dense1")(x)
    x = nn.Dropout(rate=0.5, deterministic=not train)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(self.output_size, name="dense2")(x)
    return x

We can then initialize the model:

simple_net = SimpleNetwork(hidden_size=5, output_size=1)
x = jnp.empty((3, 4, 4, 5, 5))
# With train=False, deterministic=True, so dropout is disabled and no 'dropout' key is needed.
variables = simple_net.init(params_key, x, train=False)
params = variables['params']

5.0.3 Train states

A “train state” is the state of a model that changes during training, including its parameters (weights) and optimizer state.

The train state is typically represented as an instance of the flax.training.train_state.TrainState class. It bundles the parameters, the optimizer, the optimizer state, and the step count, and provides methods such as apply_gradients to update them.

A key feature of JAX/Flax is the functional-programming principle of immutability. Model updates are purely functional: each update returns a new state rather than modifying the old one in place. This makes training code easy to compile and parallelize.

# Requires the CNN class defined in the MNIST example below
import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

5.0.4 Optax

Optax is a gradient processing and optimization package. It is generally used with Flax as follows:

1. Create an optimizer (any method, e.g. optax.rmsprop or optax.adam) and initialize its state from the parameters.
2. Compute the loss and its gradients with jax.value_and_grad().
3. Call the optimizer's update function. It uses the gradients and the current optimizer state to compute parameter updates and a new optimizer state.
4. Apply the updates to the parameters with optax.apply_updates.

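For example, on a small hypothetical linear-regression problem (the data, params, and mse below are illustrative):

import jax
import jax.numpy as jnp
import optax

# Hypothetical toy data: y = 3x + 1
learning_rate = 0.1
x_samples = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_samples = 3.0 * x_samples + 1.0
params = {'w': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}

def mse(params, x_batched, y_batched):
  """Mean squared error of the linear model x @ w + b."""
  preds = x_batched @ params['w'] + params['b']
  return jnp.mean((preds - y_batched) ** 2)

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
loss_grad_func = jax.value_and_grad(mse)

for i in range(100):
  loss, grads = loss_grad_func(params, x_samples, y_samples)
  updates, optimizer_state = optimizer.update(grads, optimizer_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss)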

MNIST Example

from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds


class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x


@jax.jit
def apply_model(state, images, labels):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)


def train_epoch(state, train_ds, batch_size, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy


def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds


def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)


def train_and_evaluate(learning_rate, momentum,
                       batch_size, num_epochs) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    learning_rate: Learning rate for the SGD optimizer.
    momentum: Momentum for the SGD optimizer.
    batch_size: Number of training examples per batch.
    num_epochs: Number of passes over the training set.

  Returns:
    The train state (which includes the `.params`).
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, learning_rate, momentum)

  for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                    batch_size,
                                                    input_rng)
    _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                              test_ds['label'])

    print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
        % (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100))

  return state
train_and_evaluate(0.01, 0.9, 128, 1)