Implement a linear model as a neural network in different JAX-based neural network libraries: flax (linen), nnx, equinox, penzai.
Visualise the models and their parameters using treescope.
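A minimal treescope sketch (the pytree here is made up; in a notebook the output is an interactive, collapsible tree):

import jax.numpy as jnp
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)   # make treescope the default IPython pretty-printer
treescope.display({"w": jnp.ones((3, 2)), "b": jnp.zeros(3)})  # render any pytree, e.g. a parameter dict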
flax
- pytree of parameters for the model
- parameters are not stored with the model itself
- cannot directly call `model(x)`: the `__call__` function is wrapped up in the `apply` function
- need to call `model.init(key, x)` to create the parameter pytree (see the sketch after the model definition below)
import flax.linen as nn

class LinearRegression(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        y_pred = nn.Dense(self.features)(x)
        return y_pred
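A minimal init/apply sketch with made-up synthetic data:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100, 1))   # synthetic inputs
y = 2.0 * x + 1.0                      # synthetic targets, used for training below

model = LinearRegression(features=1)
params = model.init(key, x)            # parameters live in a pytree outside the module
y_pred = model.apply(params, x)        # __call__ is only reachable through apply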
import optax

optimizer = optax.adam(learning_rate=0.001)
# in practice, better to bundle params and opt_state with flax.training.train_state

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(params):
        y_pred = model.apply(params, x)
        return optax.l2_loss(y_pred, y).sum()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
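Sketch of a training loop using the pieces above (the step count is arbitrary):

opt_state = optimizer.init(params)
for _ in range(1_000):
    params, opt_state, loss = train_step(params, opt_state, x, y)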
nnx
- express models using regular python objects, which are modelled as pygraphs (instead of pytrees)
- the module itself holds the state directly
  - dynamic state is usually stored in `nnx.Param`
  - access it with `graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState)` to get the `params`
- parameters are already initialised during model instantiation, so there is no need to call `model.init()`
- can call `model(x)` directly; no need for a separate `apply` method
- use `nnx.split` and `nnx.merge` to operate on a module (example from gpjax):
graphdef, state = nnx.split(posterior)  # state behaves like a pytree
updated_state = jax.tree_util.tree_map(lambda x: x + 1, state)  # increment each parameter's value by 1
updated_posterior = nnx.merge(graphdef, updated_state)  # reconstruct the posterior distribution using the updated state
from flax import nnx

class LinearRegression(nnx.Module):
    def __init__(self, in_size: int, out_size: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_size, out_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        y_pred = self.linear(x)
        return y_pred
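A minimal sketch: instantiation creates the parameters eagerly, the model can be called directly, and `nnx.split` pulls out the `nnx.Param` state when a pytree view is needed (the `...` filter catches any remaining state):

model = LinearRegression(in_size=1, out_size=1, rngs=nnx.Rngs(0))
y_pred = model(jnp.ones((4, 1)))                           # no init/apply needed
graphdef, params, rest = nnx.split(model, nnx.Param, ...)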
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
    def loss_fn(model: LinearRegression):
        y_pred = model(x)
        return optax.l2_loss(y_pred, y).sum()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates
    return loss
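Sketch of a training loop (synthetic data; the model and optimizer are updated in place):

x = jax.random.normal(jax.random.PRNGKey(0), (100, 1))
y = 2.0 * x + 1.0
for _ in range(1_000):
    loss = train_step(model, optimizer, x, y)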
equinox
- Equinox represents models as pytrees using `eqx.Module`
  - then use any jax transformation (such as `jit`, `grad`, `jax.vmap`) or other jax code
  - class-like behaviour can be obtained by avoiding extra abstractions and simply exploiting pytrees
- supports arbitrary python objects for its leaves (e.g. activation functions)
  - need to use filtering (`@eqx.filter_jit`, `@eqx.filter_grad`, etc.) to operate only on `params`, the leaves that are arrays, and not on `static`, everything else that cannot be operated on (see the partition sketch below)
- an `eqx.Module` is defined for a single input vector
  - mini-batching must be done via an external `vmap` call
- can apply `model` directly like a function
import equinox as eqx

# equivalently, use the built-in eqx.nn.Linear
class LinearLayer(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

class LinearRegression(eqx.Module):
    linear: LinearLayer  # fields must be declared: eqx.Module is a frozen dataclass

    def __init__(self, in_size: int, out_size: int, key: jax.Array):
        self.linear = LinearLayer(in_size, out_size, key=key)

    def __call__(self, x: jax.Array):
        y_pred = self.linear(x)
        return y_pred
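A minimal sketch: the model is called directly on a single input vector, partitioned into array and static leaves with `eqx.partition`, and batched with an external `vmap`:

model = LinearRegression(in_size=1, out_size=1, key=jax.random.PRNGKey(0))
params, static = eqx.partition(model, eqx.is_array)      # array leaves vs everything else
x = jax.random.normal(jax.random.PRNGKey(1), (100, 1))   # synthetic batch
y_pred_one = model(x[0])                                 # single input vector
y_pred_all = jax.vmap(model)(x)                          # mini-batching via an external vmap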
optimizer = optax.adam(learning_rate=0.001)
# it only makes sense to train the arrays in our model,
# so filter out everything else using `eqx.filter`
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def train_step(model, opt_state, x, y):
    def loss_fn(model):
        y_pred = jax.vmap(model)(x)  # vectorise the model over a batch of data
        return optax.l2_loss(y_pred, y).sum()
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
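Sketch of a training loop using the pieces above (synthetic targets):

y = 2.0 * x + 1.0
for _ in range(1_000):
    model, opt_state, loss = train_step(model, opt_state, x, y)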
- `eqx.tree_at` replaces a particular leaf or leaves of a pytree out-of-place, returning a modified copy rather than mutating
- `equinox.tree_serialise_leaves` saves the leaves of a pytree to a file
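A minimal sketch of both (the file name is made up):

new_bias = jnp.zeros_like(model.linear.bias)
model_zero_bias = eqx.tree_at(lambda m: m.linear.bias, model, new_bias)       # out-of-place replacement
eqx.tree_serialise_leaves("linear_regression.eqx", model_zero_bias)           # write the array leaves to disk
model_loaded = eqx.tree_deserialise_leaves("linear_regression.eqx", model)    # load into a matching template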
penzai
The API is in flux, but penzai has some useful concepts such as named arrays, and treescope for visualising the arrays of a pytree. Currently it is best used for inspecting trained neural networks rather than building and training them.
- Penzai models are registered as jax pytree nodes (similar to Equinox)
  - so any Penzai model can be traversed using `jax.tree_util` or passed through other jax transformations
- for a forward pass of a model, call it like a function
- Penzai model objects are immutable
  - modifications to Penzai models always occur by making a modified copy of the model
- each Penzai model tree has two types of leaves:
  - jax arrays and scalars, which are immutable and often represent hyperparameters, and
  - variable objects, which can be modified: `Parameter`s, usually modified by optimisers (not by the model), and `StateVariable`s, usually updated as the model runs.
# each layer has the same signature:
# exactly one positional argument – the input from the previous layer
class Layer(pz.Struct, abc.ABC):
    @abc.abstractmethod
    def __call__(self, argument: Any, /, **side_inputs) -> Any:
        ...

# for example
@struct.pytree_dataclass
class Attention(Layer):
    input_to_query: Layer
    input_to_key: Layer
    input_to_value: Layer
    query_key_to_attn: Layer
    attn_value_to_output: Layer

    def __call__(self, x: NamedArray, **side_inputs) -> NamedArray:
        query = self.input_to_query(x, **side_inputs)
        key = self.input_to_key(x, **side_inputs)
        value = self.input_to_value(x, **side_inputs)
        attn = self.query_key_to_attn((query, key), **side_inputs)
        output = self.attn_value_to_output((attn, value), **side_inputs)
        return output
@pz.pytree_dataclass
class LinearRegression(pz.nn.Sequential):
    sublayers: list[pz.nn.Layer]

    @classmethod
    def from_config(
        cls,
        name: str,
        init_base_rng: jax.Array | None,
        feature_sizes: Sequence[int],
        features_axis: str = "features",
    ) -> "LinearRegression":
        sublayers = []
        for i in range(len(feature_sizes) - 1):
            sublayers.append(Linear.from_config(
                name=f"{name}/block_{i}/linear",
                init_base_rng=init_base_rng,
                in_features=feature_sizes[i],
                out_features=feature_sizes[i + 1],
                features_axis=features_axis,
            ))
            sublayers.append(Bias.from_config(
                name=f"{name}/block_{i}/bias",
                init_base_rng=init_base_rng,
                features=feature_sizes[i + 1],
                features_axis=features_axis,
            ))
        return cls(sublayers)
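A sketch of building it (assuming the `Linear` and `Bias` helpers referenced above are in scope; the name and sizes are made up):

linreg = LinearRegression.from_config(
    name="linreg",
    init_base_rng=jax.random.PRNGKey(0),
    feature_sizes=[1, 1],   # a single Linear + Bias block
)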
named axes
- the `pz.nx.NamedArray` class wraps an ordinary array and assigns each axis either a position or a name (but not both)
- two shapes:
  - the `.positional_shape` attribute is the sequence of dimension sizes for the positional axes in the array
  - the `.named_shape` attribute is a dictionary mapping axis names to their dimension sizes

pz.nx.wrap(jnp.ones([3, 4, 5])).tag("foo", "bar", "baz")   # name all three axes
my_named_array.untag("bar", "foo", "baz").unwrap()         # move the names back to positions and unwrap the raw array
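Inspecting the two shapes (a sketch; expected values shown as comments):

arr = pz.nx.wrap(jnp.ones([3, 4, 5])).tag("foo", "bar", "baz")
arr.positional_shape                 # ()
arr.named_shape                      # {'foo': 3, 'bar': 4, 'baz': 5}
arr.untag("bar").positional_shape    # (4,)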
- operations on NamedArrays always act only on the positional axes; the workflow is:
  - untag: move the named axes to operate on into the `.positional_shape`
  - operate: call `pz.nx.nmap(some_positional_op)(...args...)`
  - retag: call `.tag` to move the resulting axes from the `.positional_shape` back to the `.named_shape`
def named_dot(x: pz.nx.NamedArray, features_axis: str) -> pz.nx.NamedArray:
    # `kernel` is assumed to come from the enclosing scope (e.g. a layer's weight
    # parameter) whose `.value` is a NamedArray with "out_features"/"in_features" axes
    pos_x = x.untag(features_axis)
    pos_kernel = kernel.value.untag("out_features", "in_features")
    pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)
    return pos_y.tag(features_axis)
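A self-contained sketch of the same pattern, with the kernel as a plain NamedArray rather than a parameter; any extra named axes (here "batch") are vectorised over automatically because `nmap` only sees the positional axes:

kernel_arr = pz.nx.wrap(jnp.ones([2, 4])).tag("out_features", "in_features")
x = pz.nx.wrap(jnp.ones([32, 4])).tag("batch", "features")
pos_y = pz.nx.nmap(jnp.dot)(
    kernel_arr.untag("out_features", "in_features"),   # positional (2, 4)
    x.untag("features"),                               # positional (4,); "batch" stays named
)
y = pos_y.tag("features")   # named_shape: {'batch': 32, 'features': 2}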