import jax
from jax import numpy as jnp, random, lax, jit
from flax import linen as nn

# TODO: Data preparation
# Create variables X and Y:
# X is a 1 x 10 matrix
# Y is a 1-dimensional array of size 5
# Some references on generating matrices here:
# https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html
X = None
Y = None

# TODO: create a model of one Dense layer with 5 features
# For help:
# https://flax.readthedocs.io/en/latest/guides/flax_basics.html
model = nn.Dense(features=5)


@jit
def predict(params):
    return model.apply({'params': params}, X)


@jit
def loss_fn(params):
    return jnp.mean(jnp.abs(Y - predict(params)))


# Initialize the model with random values:
# use the random number generator ('rng') as input
# to initialize params based on the input shape of 'X'.
@jit
def init_params(rng):
    mlp_variables = model.init({'params': rng}, X)
    return mlp_variables['params']


# TODO
# use the init_params function and
# jax.random to initialize random params
# using PRNGKey
params = None
print("initial params", params)

# Run SGD.
for i in range(50):
    # TODO: use jax transformations to extract the loss value
    # and the gradients of the loss with respect to the params
    loss, grad = None, None
    print(i, "loss = ", loss, "Yhat = ", predict(params))
    lr = 0.03
    params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad)
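

# ---------------------------------------------------------------------------
# Reference sketch: one possible way to fill in the TODOs above, not
# necessarily the intended solution. The PRNG seed and the use of
# random.normal to generate X and Y are assumptions made for illustration.
# Calling _reference_sketch() runs 50 SGD steps and returns the final params.
# ---------------------------------------------------------------------------
def _reference_sketch():
    # Split one key for data generation and parameter initialization.
    rng = random.PRNGKey(0)
    x_rng, y_rng, init_rng = random.split(rng, 3)

    # Data: X is a 1 x 10 matrix, Y is a 1-dimensional array of size 5.
    X = random.normal(x_rng, (1, 10))
    Y = random.normal(y_rng, (5,))

    # One Dense layer mapping the 10 input features to 5 outputs.
    model = nn.Dense(features=5)

    def predict(params):
        return model.apply({'params': params}, X)

    def loss_fn(params):
        return jnp.mean(jnp.abs(Y - predict(params)))

    # Initialize params from the PRNGKey, based on the input shape of X.
    params = model.init({'params': init_rng}, X)['params']

    # SGD: value_and_grad returns the loss and its gradient in one call.
    for i in range(50):
        loss, grad = jax.value_and_grad(loss_fn)(params)
        print(i, "loss =", loss, "Yhat =", predict(params))
        lr = 0.03
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)
    return params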