6  Exercise 2: Linear Regression in Flax

Open In Colab

import jax
from jax import numpy as jnp, random, lax, jit
from flax import linen as nn

# Data preparation
# X is a 1 x 10 matrix: a single input example with 10 features.
# Y is a 1-dimensional array of size 5: one target value per output
# feature of the 5-feature Dense layer defined below.
# Some references on generating matrices here:
# https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html

X = jnp.ones((1, 10))
Y = jnp.ones((5,))

# Model: a single Dense layer with 5 output features.
# The module object holds no parameters itself; they are created by
# model.init(...) and passed back in explicitly through model.apply(...).
# For help:
# https://flax.readthedocs.io/en/latest/guides/flax_basics.html
model = nn.Dense(features=5)

@jit
def predict(params):
  """Run the module-level `model` on the module-level input `X`.

  Args:
    params: the 'params' collection to use for the forward pass.

  Returns:
    The output of `model.apply` for `X` under `params`.
  """
  variables = {'params': params}
  return model.apply(variables, X)

@jit
def loss_fn(params):
  """Mean absolute error between the targets `Y` and the model prediction.

  Args:
    params: the 'params' collection forwarded to `predict`.

  Returns:
    The mean of |Y - predict(params)| over all elements.
  """
  residual = Y - predict(params)
  absolute_error = jnp.abs(residual)
  return jnp.mean(absolute_error)

# Build the model's initial parameters from random values.
# The random number generator key ('rng') seeds the 'params' stream,
# and the shape of 'X' determines the shapes of the parameters.
@jit
def init_params(rng):
  """Initialize `model` for input `X` and return its 'params' collection.

  Args:
    rng: PRNG key passed to `model.init` under the 'params' name.

  Returns:
    The 'params' entry of the variables produced by `model.init`.
  """
  return model.init({'params': rng}, X)['params']

# Initialize the random params by passing a PRNG key, built with
# jax.random.PRNGKey, to the init_params function.
# The seed is fixed so every run starts from the same initial params.
params = init_params(random.PRNGKey(42))
print("initial params", params)

# Run SGD.
lr = 0.03  # learning rate; constant, so it is set once outside the loop

# jax.value_and_grad transforms loss_fn into a function that returns both
# the loss value and the gradient of the loss with respect to the params.
loss_and_grad = jax.value_and_grad(loss_fn)

for i in range(50):
  loss, grad = loss_and_grad(params)
  print(i, "loss = ", loss, "Yhat = ", predict(params))
  # Gradient-descent step: grad has the same tree structure as params, so
  # tree_map updates every parameter leaf with its matching gradient leaf.
  params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad)