Implement a linear model as a neural network in different JAX-based neural network libraries: equinox; nnx; flax linen (deprecated in favour of the nnx API, where we can call model(x) directly, with no need for a separate apply method); and penzai (a stale project, but one that made some useful contributions such as treescope and named array axes).
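For contrast, a minimal sketch of the same model in the older linen API (the sizes and seed here are illustrative), showing the explicit parameter pytree and separate apply step that nnx removes:

import jax
import jax.numpy as jnp
import flax.linen as nn

class LinearRegression(nn.Module):
    out_size: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.out_size)(x)

model = LinearRegression(out_size=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1,)))  # explicit parameter pytree
y_pred = model.apply(params, jnp.ones((1,)))                # separate apply method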
nnx
from flax import nnx
import jax
import optax

class LinearRegression(nnx.Module):
    def __init__(self, in_size: int, out_size: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_size, out_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        y_pred = self.linear(x)
        return y_pred

model = LinearRegression(1, 1, rngs=nnx.Rngs(0))  # illustrative sizes and seed; params are created eagerly
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
    def loss_fn(model: LinearRegression):
        y_pred = model(x)
        return optax.l2_loss(y_pred, y).sum()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates
    return loss
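A minimal training loop for the nnx version (the toy data and step count are illustrative):

import jax.numpy as jnp

x = jnp.linspace(0, 1, 32).reshape(-1, 1)  # toy data for y = 2x + 1
y = 2 * x + 1
for _ in range(1000):
    loss = train_step(model, optimizer, x, y)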
equinox

An eqx.Module is defined for a single input vector:
- mini-batching must be done via an external `vmap` call (see the short sketch after this list)
- the model can be applied directly like a function
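A toy illustration of both points; the lambda below is a hypothetical stand-in for an eqx.Module, not real equinox API:

import jax
import jax.numpy as jnp

model = lambda x: x.sum()           # stand-in: a callable defined for one vector
y = model(jnp.ones(3))              # apply directly like a function
x_batch = jnp.ones((8, 3))          # batch of 8 vectors
y_batch = jax.vmap(model)(x_batch)  # vectorise over the leading batch axis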
import equinox as eqx
import jax
import optax

# equivalently, use the built-in eqx.nn.Linear
class LinearLayer(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias
class LinearRegression(eqx.Module):
    linear: LinearLayer  # eqx.Module fields must be declared as annotations

    def __init__(self, in_size: int, out_size: int, key: jax.Array):
        self.linear = LinearLayer(in_size, out_size, key=key)

    def __call__(self, x: jax.Array):
        y_pred = self.linear(x)
        return y_pred

model = LinearRegression(1, 1, key=jax.random.PRNGKey(0))  # illustrative sizes and seed
optimizer = optax.adam(learning_rate=0.001)
# it only makes sense to train the arrays in our model,
# so filter out everything else using `eqx.filter`
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def train_step(model, opt_state, x, y):
    def loss_fn(model):
        y_pred = jax.vmap(model)(x)  # vectorise the model over a batch of data
        return optax.l2_loss(y_pred, y).sum()
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
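The same illustrative training loop for the equinox version; unlike nnx, the updated model and optimizer state are threaded through explicitly rather than mutated in place:

import jax.numpy as jnp

x = jnp.linspace(0, 1, 32).reshape(-1, 1)  # toy data for y = 2x + 1
y = 2 * x + 1
for _ in range(1000):
    model, opt_state, loss = train_step(model, opt_state, x, y)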