from functools import partial
from flax import linen as nn
import jax.numpy as jnp
from typing import Any, Callable, Sequence, Tuple
7 Let’s build a ResNet!
Now that we have an understanding of JAX, and of how to build neural network layers and optimizers in Flax, let’s put our skills to work implementing one of the most cited academic publications in machine learning: Deep Residual Learning for Image Recognition.
7.0.1 The idea
The paper’s authors found that, even with batch normalization, networks with more layers often performed worse than those with fewer.
(Image from the paper linked above, by Kaiming He and others.)
The researchers experimented with the idea of adding extra layers that behave as an ‘identity mapping’: layers that have parameters and are trainable, but that can simply return their inputs unchanged, so stacking them should never make the network worse.
This idea resulted in ‘skip connections’, which leap over one or more convolutions, as in the diagram from the paper. Skip connections make larger architectures easier to train, because the shortcut path gives gradients a direct route back through the network.
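Here is a minimal sketch of the idea in Flax (illustrative only, not code from the paper): a hypothetical SkipBlock runs two convolutions and then adds the original input back onto the result. It reuses the imports from the top of this chapter.

class SkipBlock(nn.Module):  # hypothetical module, for illustration only
    features: int  # must match the channel count of x so that x + y is valid

    @nn.compact
    def __call__(self, x):
        y = nn.Conv(self.features, (3, 3))(x)   # the 'main' path
        y = nn.relu(y)
        y = nn.Conv(self.features, (3, 3))(y)
        return x + y                             # the skip connection: the block only learns a correction to x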
7.0.1.1 Residuals
A ‘residual’ is basically the difference between the desired output and the input. Rather than asking a block to learn the full mapping directly, a ResNet block learns the residual F(x) = H(x) − x and then adds the input back on, so its output is F(x) + x. This architecture proved very strong at picking up slight differences in images (is it a wolf or an off-leash Husky running through a dark forest?).
class ResNetBlock(nn.Module):
    """A standard ResNet block: two 3x3 convolutions plus a skip connection."""
    filters: int
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x, train: bool = True):
        residual = x
        y = nn.Conv(self.filters, (3, 3), self.strides)(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3))(y)
        y = nn.BatchNorm(use_running_average=not train)(y)

        # If the strides or filter count changed y's shape, project the
        # residual with a 1x1 convolution so the two can be added.
        if residual.shape != y.shape:
            residual = nn.Conv(self.filters, (1, 1),
                               self.strides, name='conv_proj')(residual)
            residual = nn.BatchNorm(use_running_average=not train,
                                    name='norm_proj')(residual)

        # The skip connection: add the (possibly projected) input back in.
        return nn.relu(residual + y)
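To sanity-check the block, we can initialise it on a dummy batch and confirm that the projection path is used when the strides change the shape. This snippet is purely illustrative, and the shapes are arbitrary.

import jax

block = ResNetBlock(filters=64, strides=(2, 2))
x = jnp.ones((1, 32, 32, 32))                          # dummy NHWC batch with 32 channels
variables = block.init(jax.random.PRNGKey(0), x, train=False)
y = block.apply(variables, x, train=False)
print(y.shape)                                         # (1, 16, 16, 64): spatially halved, 64 filters
print('conv_proj' in variables['params'])              # True: the residual needed projecting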
7.0.2 Bottleneck layers
To enable training deeper models without spiking memory and computation use, we can use bottleneck layers. These were also introduced in the original paper as suitable for ResNets with a depth of 50 or more layers.
In our original ResNet block, we have two convolutions with kernel size 3. Bottleneck layers use a 1 x 1 convolution at the start and end, and a 3 x 3 convolution in between.
The first 1 x 1 convolution reduces the number of channels before the expensive 3 x 3 convolution, and the last one restores (in fact expands) them again – hence the name ‘bottleneck’. This keeps deeper models affordable even as we add more filters.
class BottleneckResNetBlock(nn.Module):
    """A bottleneck ResNet block: 1x1, 3x3, 1x1 convolutions plus a skip connection."""
    filters: int
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x, train: bool = True):
        residual = x
        y = nn.Conv(self.filters, (1, 1))(x)                 # squeeze the channels
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), self.strides)(y)   # the expensive 3x3 convolution
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters * 4, (1, 1))(y)             # restore (expand) the channels
        y = nn.BatchNorm(use_running_average=not train,
                         scale_init=nn.initializers.zeros_init())(y)

        # Project the residual when the strides or filter count changed y's shape.
        if residual.shape != y.shape:
            residual = nn.Conv(self.filters * 4, (1, 1),
                               self.strides, name='conv_proj')(residual)
            residual = nn.BatchNorm(use_running_average=not train,
                                    name='norm_proj')(residual)

        return nn.relu(residual + y)
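To see what the squeeze-and-restore trick buys us, we can count parameters for both block types on a 256-channel input. The snippet below is purely illustrative, and the exact numbers depend on the chosen shapes.

import jax

x = jnp.ones((1, 56, 56, 256))                          # a 256-channel feature map

def param_count(block):
    variables = block.init(jax.random.PRNGKey(0), x, train=False)
    return sum(p.size for p in jax.tree_util.tree_leaves(variables['params']))

print(param_count(ResNetBlock(filters=256)))            # two full 3x3 convolutions on 256 channels
print(param_count(BottleneckResNetBlock(filters=64)))   # squeeze to 64, 3x3, expand back to 256: far fewer parameters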
7.0.3 Creating the ResNet
Now we have our blocks, we can stack them together to create a fully-fledged ResNet.
Blocks are generally grouped by feature-map size and filter count, so if our model has [2,2,2,2] blocks, it means we have four groups of 2 blocks each.
We will use the original ResNetBlock for architectures with fewer than 50 layers, and the BottleneckResNetBlock for larger ones.
The various ResNet sizes (ResNet18, ResNet50, etc.) simply denote the number of layers.
A ResNet18’s blocks are [2,2,2,2], so how do we get to 18 layers?
We have 1 initial conv layer, then 8 residual blocks with 2 conv layers each (16 conv layers), and finally a single fully connected layer that serves as the classifier head and outputs the predictions: 1 + 16 + 1 = 18.
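As a quick sanity check on that arithmetic:

blocks = [2, 2, 2, 2]                          # ResNet18's block configuration
convs_per_block = 2                            # each ResNetBlock contains two 3x3 convolutions
print(1 + sum(blocks) * convs_per_block + 1)   # initial conv + block convs + final dense layer = 18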
from functools import partial
class ResNet(nn.Module):
    num_classes: int
    block_class: nn.Module
    num_blocks: Sequence[int]
    filters: int = 64
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = True):
        # Stem: a strided 7x7 convolution followed by max pooling.
        x = nn.Conv(self.filters, (7, 7), (2, 2),
                    padding=[(3, 3), (3, 3)],
                    use_bias=False,
                    name='conv_init')(x)
        x = nn.BatchNorm(name='bn_init')(x, use_running_average=not train)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), (2, 2), padding='SAME')

        # The residual blocks, grouped by feature-map size; each group after
        # the first starts with a strided block that halves the spatial size.
        for i, block_size in enumerate(self.num_blocks):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_class(self.filters * (2**i),
                                     strides=strides)(x, train=train)

        # Global average pooling, then the classifier head.
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        return x
ResNet18 = partial(ResNet, num_blocks=[2, 2, 2, 2],
                   block_class=ResNetBlock)
ResNet34 = partial(ResNet, num_blocks=[3, 4, 6, 3],
                   block_class=ResNetBlock)
ResNet50 = partial(ResNet, num_blocks=[3, 4, 6, 3],
                   block_class=BottleneckResNetBlock)
ResNet101 = partial(ResNet, num_blocks=[3, 4, 23, 3],
                    block_class=BottleneckResNetBlock)
ResNet152 = partial(ResNet, num_blocks=[3, 8, 36, 3],
                    block_class=BottleneckResNetBlock)
ResNet200 = partial(ResNet, num_blocks=[3, 24, 36, 3],
                    block_class=BottleneckResNetBlock)