# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""
import array
import gzip
import os
from os import path
import struct
import urllib.request
import numpy as np
= "/tmp/jax_example_data/"
_DATA
def _download(url, filename):
"""Download a url to a file in the JAX data temp directory."""
if not path.exists(_DATA):
os.makedirs(_DATA)= path.join(_DATA, filename)
out_file if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)print(f"downloaded {url} to {_DATA}")
def _partial_flatten(x):
"""Flatten all but the first dimension of an ndarray."""
return np.reshape(x, (x.shape[0], -1))
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
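The `_one_hot` helper relies on NumPy broadcasting: comparing the label column vector `x[:, None]` against the row vector `np.arange(k)` yields a boolean matrix with exactly one True per row, which the cast then turns into floats. A small illustration (not part of the original file):

labels = np.array([0, 2, 1])
print(_one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]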
def mnist_raw():
  """Download and parse the raw MNIST dataset."""
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  def parse_labels(filename):
    with gzip.open(filename, "rb") as fh:
      _ = struct.unpack(">II", fh.read(8))
      return np.array(array.array("B", fh.read()), dtype=np.uint8)

  def parse_images(filename):
    with gzip.open(filename, "rb") as fh:
      _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      return np.array(array.array("B", fh.read()),
                      dtype=np.uint8).reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels


def mnist(permute_train=False):
  """Download, parse and process MNIST data to unit scale and one-hot labels."""
  train_images, train_labels, test_images, test_labels = mnist_raw()

  train_images = _partial_flatten(train_images) / np.float32(255.)
  test_images = _partial_flatten(test_images) / np.float32(255.)
  train_labels = _one_hot(train_labels, 10)
  test_labels = _one_hot(test_labels, 10)

  if permute_train:
    perm = np.random.RandomState(0).permutation(train_images.shape[0])
    train_images = train_images[perm]
    train_labels = train_labels[perm]

  return train_images, train_labels, test_images, test_labels
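Calling `mnist()` returns flattened, unit-scale images and one-hot labels. A minimal usage sketch (the shapes follow from the standard 60,000/10,000 MNIST split):

train_images, train_labels, test_images, test_labels = mnist()
print(train_images.shape)  # (60000, 784), float32 values in [0, 1]
print(train_labels.shape)  # (60000, 10), one-hot float32
print(test_images.shape)   # (10000, 784)
print(test_labels.shape)   # (10000, 10)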
11 Exercise 1 Solution
With thanks to DeepMind for the code in this solution.
"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""
import time
import numpy.random as npr
from jax import jit, grad, random
from jax.scipy.special import logsumexp
import jax.numpy as jnp
from examples import datasets
def init_random_params(scale, layer_sizes):
  # Solution
  key = random.PRNGKey(0)
  # Split the PRNGKey into two new keys
  key1, key2 = random.split(key, 2)
  params = [(scale * random.normal(key1, (m, n)),
             scale * random.normal(key2, (n,)))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
  return params
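Note that this solution draws every layer's weights from `key1` and every layer's biases from `key2`, so two parameters with the same shape (here, the two hidden-layer bias vectors) receive identical initial values. A variant that derives a fresh pair of keys per layer (a sketch, not the solution as given) could look like:

def init_random_params_per_layer(scale, layer_sizes, seed=0):
  # One key per layer, each split into a weight key and a bias key.
  keys = random.split(random.PRNGKey(seed), len(layer_sizes) - 1)
  params = []
  for layer_key, (m, n) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:])):
    w_key, b_key = random.split(layer_key)
    params.append((scale * random.normal(w_key, (m, n)),
                   scale * random.normal(b_key, (n,))))
  return params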
def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = jnp.dot(activations, w) + b
    activations = jnp.tanh(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)
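Because `predict` subtracts `logsumexp` along the class axis, each row of its output is a vector of log-probabilities. A quick sanity check (illustrative only; the parameters and inputs are arbitrary):

demo_params = init_random_params(0.1, [784, 1024, 1024, 10])
demo_inputs = jnp.ones((4, 784))
log_probs = predict(demo_params, demo_inputs)
print(jnp.exp(log_probs).sum(axis=1))  # approximately [1. 1. 1. 1.]

This is also what makes `loss` the mean cross-entropy between the one-hot targets and the predicted distribution.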
if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()
  # Solution
  # jit-compiling the update function brings the most benefit:
  # it does the heavy lifting of the training loop and runs
  # many times.
  @jit
  def update(params, batch):
    # Solution
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")