import jax
from jax import numpy as jnp, random, lax, jit
from flax import linen as nn
# Layer used by every function below: a Dense module configured with features=5.
model = nn.Dense(features=5)

# Fixed input and target used throughout the script: all-ones arrays of
# shape (1, 10) and (5,) respectively.
X = jnp.ones(shape=(1, 10))
Y = jnp.ones(shape=(5,))
@jit
def predict(params):
    """Run the module-level ``model`` on the fixed input ``X``.

    Args:
        params: Parameter tree passed to ``model.apply`` under the
            ``'params'`` key.

    Returns:
        Whatever ``model.apply`` returns for ``X`` with these parameters.
    """
    variables = {'params': params}
    return model.apply(variables, X)
@jit
def loss_fn(params):
    """Return the mean absolute difference between ``Y`` and the prediction.

    Args:
        params: Parameter tree forwarded to ``predict``.

    Returns:
        The mean of ``|Y - predict(params)|`` over all elements.
    """
    residual = Y - predict(params)
    abs_error = jnp.abs(residual)
    return jnp.mean(abs_error)
@jit
def init_params(rng):
    """Initialise ``model`` on ``X`` and return its ``'params'`` collection.

    Args:
        rng: PRNG key handed to ``model.init`` under the ``'params'`` name.

    Returns:
        The ``'params'`` entry of the variables produced by ``model.init``.
    """
    return model.init({'params': rng}, X)['params']
# Get initial parameters.
params = init_params(jax.random.PRNGKey(42))
print("initial params", params)

# Run SGD.
# The step size and the value-and-gradient transform are the same on every
# iteration, so both are set up once here instead of inside the loop.
lr = 0.03
loss_and_grad = jax.value_and_grad(loss_fn)
for i in range(50):
    loss, grad = loss_and_grad(params)
    # Report the loss and prediction for the parameters *before* this
    # iteration's update is applied.
    print(i, "loss = ", loss, "Yhat = ", predict(params))
    # Gradient-descent step applied to each parameter leaf together with its
    # matching gradient leaf.
    params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad)