4 Exercise 1: MNIST in JAX

Code mostly ported with thanks from DeepMind’s examples.

Using minimal dependencies and pure JAX functions, we train a simple neural network to classify MNIST digits.

First, some data loading functions. JAX can also leverage both TensorFlow and PyTorch data loading capabilities; a sketch of a PyTorch-based alternative follows the data loading code below.

import array
import gzip
import os
from os import path
import struct
import urllib.request
import numpy as np
= "/tmp/jax_example_data/"
_DATA
def _download(url, filename):
"""Download a url to a file in the JAX data temp directory."""
if not path.exists(_DATA):
os.makedirs(_DATA)= path.join(_DATA, filename)
out_file if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)print(f"downloaded {url} to {_DATA}")

def _partial_flatten(x):
  """Flatten all but the first dimension of an ndarray."""
  return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
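# For example, _one_hot(np.array([0, 2]), 3) yields [[1., 0., 0.], [0., 0., 1.]],
# and _partial_flatten turns a stack of 28x28 images of shape (N, 28, 28) into
# a matrix of shape (N, 784).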

def mnist_raw():
  """Download and parse the raw MNIST dataset."""
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  def parse_labels(filename):
    with gzip.open(filename, "rb") as fh:
      _ = struct.unpack(">II", fh.read(8))
      return np.array(array.array("B", fh.read()), dtype=np.uint8)

  def parse_images(filename):
    with gzip.open(filename, "rb") as fh:
      _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      return np.array(array.array("B", fh.read()),
                      dtype=np.uint8).reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels
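# mnist_raw returns raw uint8 arrays: train images of shape (60000, 28, 28),
# train labels of shape (60000,), and the corresponding test arrays with
# 10000 examples.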

def mnist(permute_train=False):
  """Download, parse and process MNIST data to unit scale and one-hot labels."""
  train_images, train_labels, test_images, test_labels = mnist_raw()

  train_images = _partial_flatten(train_images) / np.float32(255.)
  test_images = _partial_flatten(test_images) / np.float32(255.)
  train_labels = _one_hot(train_labels, 10)
  test_labels = _one_hot(test_labels, 10)

  if permute_train:
    perm = np.random.RandomState(0).permutation(train_images.shape[0])
    train_images = train_images[perm]
    train_labels = train_labels[perm]

  return train_images, train_labels, test_images, test_labels
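As noted above, JAX can also consume batches produced by TensorFlow or PyTorch input pipelines, since it operates on plain NumPy arrays. The following is a minimal sketch of a PyTorch-based loader, assuming torch and torchvision are installed; the helper name numpy_collate and the /tmp/mnist/ path are illustrative choices, not part of the exercise.

import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets

def numpy_collate(batch):
  # Stack (PIL image, int label) pairs into flat float32 / int NumPy arrays.
  images = np.stack([np.ravel(np.array(img, dtype=np.float32)) / 255.
                     for img, _ in batch])
  labels = np.array([label for _, label in batch])
  return images, labels

mnist_train = datasets.MNIST("/tmp/mnist/", train=True, download=True)
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True,
                          collate_fn=numpy_collate)
# Each iteration then yields NumPy arrays that JAX functions accept directly:
#   for images, labels in train_loader: ...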
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A basic MNIST example using Numpy and JAX.
"""
import time
import numpy.random as npr
from jax import jit, grad, random
from jax.scipy.special import logsumexp
import jax.numpy as jnp
from examples import datasets
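# Note: numpy.random (npr) is used below only for host-side batch shuffling,
# while jax.random supplies the PRNG for parameter initialization. The
# `from examples import datasets` import assumes the script lives in the JAX
# repository; in this exercise the equivalent loading functions are defined
# above and mnist() is called directly.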
# TODO: If you could only @jit one of these functions,
# which would be the best candidate?
# Add the @jit decorator to the function of your choice.

def init_random_params(scale, layer_sizes):
  # TODO Initialize a random PRNGKey
  key = pass
  # TODO Split the PRNGKey into two new keys
  key1, key2 = pass
  params = [(scale * random.normal(key1, (m, n)), scale * random.normal(key2, (n,)))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
  print(params)
  return params
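# params is a list with one (weights, biases) pair per layer: for consecutive
# layer sizes (m, n), the weight matrix has shape (m, n) and the bias vector
# has shape (n,).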

def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = jnp.dot(activations, w) + b
    activations = jnp.tanh(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
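# Since predict returns log-softmax outputs (logits minus their logsumexp),
# this loss is the cross-entropy between the one-hot targets and the predicted
# class log-probabilities.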

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)
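  # For the standard MNIST training split, num_train is 60000, so
  # divmod(60000, 128) gives 468 complete batches with 96 examples left over,
  # and num_batches is 469.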
  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()
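  # data_stream cycles forever over freshly shuffled permutations; because
  # 60000 is not a multiple of 128, the final batch of each pass contains only
  # the leftover 96 examples.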
  def update(params, batch):
    # TODO: use JAX's transformations to
    # find the gradients of the loss function w.r.t. params, batch
    # Replace `pass` with your code
    grads = pass
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")