Implement a linear model as a neural network in different jax-based neural network libraries: flax (linen), nnx, equinox, and penzai.

Visualise using treescope.
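A minimal hedged sketch (the pytree here is an arbitrary stand-in; treescope.show and treescope.basic_interactive_setup are its notebook entry points):
import treescope
import jax.numpy as jnp
 
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
treescope.show(params)  # foldable tree view of the pytree, with array summaries
 
# optional: make treescope the default IPython renderer and auto-visualise arrays
treescope.basic_interactive_setup(autovisualize_arrays=True)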

flax

  • the model's state is a pytree of parameters, kept separate from the model definition
    • parameters are not stored on the model object itself
    • cannot directly call model(x)
      • __call__ is wrapped up in the apply method: model.apply(params, x)
      • need to call params = model.init(key, x) first to create the parameter pytree
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
 
class LinearRegression(nn.Module):
	features: int
	
	@nn.compact
	def __call__(self, x):
		y_pred = nn.Dense(self.features)(x)
		return y_pred
# instantiate the model and create the parameter pytree with model.init
model = LinearRegression(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))  # dummy batch (3 input features) to infer shapes
 
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)
 
# better to use flax.training.train_state.TrainState to bundle params, apply_fn and the optimizer state
@jax.jit
def train_step(params, opt_state, x, y):
	def loss_fn(params):
		y_pred = model.apply(params, x)
		return optax.l2_loss(y_pred, y).sum()
 
	loss, grads = jax.value_and_grad(loss_fn)(params)
	updates, opt_state = optimizer.update(grads, opt_state, params)
	params = optax.apply_updates(params, updates)
	return params, opt_state, loss
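A hedged usage sketch with toy data (shapes, targets and step count are arbitrary choices):
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))   # 32 examples, 3 input features
y = x @ jnp.ones((3, 1)) + 0.5        # toy regression targets
 
for _ in range(100):
	params, opt_state, loss = train_step(params, opt_state, x, y)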

nnx

  • express models as regular python objects, which are modelled as object graphs (instead of pytrees)
  • the module itself holds its state directly
    • dynamic state is usually stored in nnx.Param attributes
    • access it with graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState)
    • parameters are initialised eagerly when the model is instantiated, so there is no need to call model.init()
    • can call model(x) directly; no need for a separate apply method
  • use nnx.split and nnx.merge to operate on a module functionally (example adapted from gpjax):
graphdef, state = nnx.split(posterior) # state behaves like a pytree
updated_state = jax.tree_util.tree_map(lambda x: x + 1, state) # increment each parameter's value by 1
updated_posterior = nnx.merge(graphdef, updated_state) # reconstruct the posterior distribution using the updated state
import jax
import optax
from flax import nnx
 
class LinearRegression(nnx.Module):
	def __init__(self, in_size: int, out_size: int, *, rngs: nnx.Rngs):
		self.linear = nnx.Linear(in_size, out_size, rngs=rngs)
 
	def __call__(self, x: jax.Array):
		y_pred = self.linear(x)
		return y_pred
model = LinearRegression(3, 1, rngs=nnx.Rngs(0))  # e.g. 3 input features -> 1 output; parameters are created eagerly here
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
 
@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
	def loss_fn(model: LinearRegression):
		y_pred = model(x)
		return optax.l2_loss(y_pred, y).sum()
		
	loss, grads = nnx.value_and_grad(loss_fn)(model)
	optimizer.update(grads)  # inplace updates
 
	return loss
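Unlike the flax version, nothing needs to be re-threaded through the loop; a hedged usage sketch, reusing the toy x, y from the flax sketch above:
for _ in range(100):
	loss = train_step(model, optimizer, x, y)  # model and optimizer are updated in place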

equinox

  • Equinox represents models as pytrees using eqx.Module
    • then use any jax transformation (such as jax.jit, jax.grad, jax.vmap) or other jax code
    • class-like behaviour is obtained without extra abstractions, simply by exploiting pytrees
  • supports arbitrary python objects as leaves (e.g. an activation function)
    • need to use filtering (@eqx.filter_jit, @eqx.filter_grad, etc.) to operate on:
      • the array leaves (the trainable parameters); and not
      • the static leaves (everything else, which jax cannot trace or differentiate)
  • an eqx.Module is conventionally written for a single input vector
    • mini-batching is done via an external jax.vmap call
    • can apply the model directly like a function (see the instantiation sketch after the class definitions below)
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
 
# equivalently, use the built-in eqx.nn.Linear
class LinearLayer(eqx.Module):
	weight: jax.Array
	bias: jax.Array
	def __init__(self, in_size, out_size, key):
		wkey, bkey = jax.random.split(key)
		self.weight = jax.random.normal(wkey, (out_size, in_size))
		self.bias = jax.random.normal(bkey, (out_size,))
		
	def __call__(self, x):
		return self.weight @ x + self.bias
 
class LinearRegression(eqx.Module):
	linear: LinearLayer  # eqx.Module is a dataclass, so fields must be declared
 
	def __init__(self, in_size: int, out_size: int, key: jax.Array):
		self.linear = LinearLayer(in_size, out_size, key=key)
 
	def __call__(self, x: jax.Array):
		y_pred = self.linear(x)
		return y_pred
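A hedged instantiation sketch (the sizes are arbitrary choices): the model is called directly on a single input vector, and mini-batched with an external jax.vmap.
key = jax.random.PRNGKey(0)
model = LinearRegression(in_size=3, out_size=1, key=key)
 
y_single = model(jnp.ones(3))                 # one example: output shape (1,)
y_batch = jax.vmap(model)(jnp.ones((32, 3)))  # mini-batch via external vmap: shape (32, 1)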
optimizer = optax.adam(learning_rate=0.001)
 
# it only makes sense to train the arrays in our model,
# so filter out everything else using `eqx.filter`
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
 
@eqx.filter_jit
def train_step(model, opt_state, x, y):
	def loss_fn(model):
		y_pred = jax.vmap(model)(x)  # vectorise the model over a batch of data
		return optax.l2_loss(y_pred, y).sum()
 
	loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
	updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
	model = eqx.apply_updates(model, updates)
	return model, opt_state, loss
  • eqx.tree_at replaces a particular leaf or leaves of a pytree, returning a modified copy (the pytree itself is immutable)
  • eqx.tree_serialise_leaves saves the leaves of a pytree to a file (and eqx.tree_deserialise_leaves loads them back)
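A hedged sketch of both, reusing the model from above (the file name and the chosen leaf are arbitrary):
# eqx.tree_at returns a modified copy; here the bias leaf is zeroed out
model_zero_bias = eqx.tree_at(lambda m: m.linear.bias, model, replace=jnp.zeros_like(model.linear.bias))
 
# save the array leaves to disk, then load them back into the same pytree structure
eqx.tree_serialise_leaves("linear_regression.eqx", model)
model_loaded = eqx.tree_deserialise_leaves("linear_regression.eqx", like=model)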

penzai

The API is in flux, but penzai has some useful concepts, such as named arrays and treescope for visualising the arrays of a pytree. It is currently best used for inspecting trained neural networks rather than for building and training them.

  • Penzai models are registered as jax pytree nodes (similar to Equinox)
    • so any Penzai model can be traversed with jax.tree_util or passed through other jax transformations
    • for a forward pass, call the model like a function
  • Penzai model objects are immutable
    • modifications to Penzai models always occur by making a modified copy of the model
  • Each Penzai model tree has two types of leaves:
    • jax arrays and scalars, which are immutable and often represent hyperparameters, and
    • variable objects, which can be modified:
      • Parameters, usually modified by optimisers (not by the model),
      • StateVariables, usually updated as the model runs.
import abc
from typing import Any, Sequence
 
import jax
from penzai import pz
 
# each layer has the same signature:
# exactly one positional argument – the input from the previous layer
class Layer(pz.Struct, abc.ABC):
	@abc.abstractmethod
	def __call__(self, argument: Any, /, **side_inputs) -> Any:
		...
 
# for example
@pz.pytree_dataclass
class Attention(Layer):
	input_to_query: Layer
	input_to_key: Layer
	input_to_value: Layer
	query_key_to_attn: Layer
	attn_value_to_output: Layer
 
	def __call__(self, x: pz.nx.NamedArray, **side_inputs) -> pz.nx.NamedArray:
		query = self.input_to_query(x, **side_inputs)
		key = self.input_to_key(x, **side_inputs)
		value = self.input_to_value(x, **side_inputs)
		attn = self.query_key_to_attn((query, key), **side_inputs)
		output = self.attn_value_to_output((attn, value), **side_inputs)
		return output
@pz.pytree_dataclass
class LinearRegression(pz.nn.Sequential):
	sublayers: list[pz.nn.Layer]
 
	@classmethod
	def from_config(
		cls,
		name: str,
		init_base_rng: jax.Array | None,
		feature_sizes: Sequence[int],
		features_axis: str = "features",
	) -> "LinearRegression":
		sublayers = []
		for i in range(len(feature_sizes) - 1):
			sublayers.append(Linear.from_config(
				name=f"{name}/block_{i}/linear",
				init_base_rng=init_base_rng,
				in_features=feature_sizes[i],
				out_features=feature_sizes[i + 1],
				features_axis=features_axis,
			))
			sublayers.append(Bias.from_config(
				name=f"{name}/block_{i}/bias",
				init_base_rng=init_base_rng,
				features=feature_sizes[i + 1],
				features_axis=features_axis,
			))
		return cls(sublayers)
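A hedged sketch of building and calling the model (feature sizes and data are arbitrary; Linear and Bias are assumed to be penzai-style layers as used in from_config above). The forward pass is just a function call on a NamedArray, and treescope can render the resulting model tree:
import jax.numpy as jnp
import treescope
 
model = LinearRegression.from_config(
	name="linreg",
	init_base_rng=jax.random.PRNGKey(0),
	feature_sizes=[3, 1],
)
 
x = pz.nx.wrap(jnp.ones([3])).tag("features")
y_pred = model(x)
 
treescope.show(model)  # inspect parameters and sublayers interactively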

example training loop

named axes

  • pz.nx.NamedArray class wraps an ordinary array, and assigns each axis to either a position or a name (but not both)
  • two shapes:
    • .positional_shape attribute is the sequence of dimension sizes for the positional axes in the array
    • .named_shape attribute is a dictionary mapping axis names to their dimension sizes
my_named_array = pz.nx.wrap(jnp.ones([3, 4, 5])).tag("foo", "bar", "baz")
# named_shape == {"foo": 3, "bar": 4, "baz": 5}; positional_shape == ()
my_named_array.untag("bar", "foo", "baz").unwrap()  # back to an ordinary array, shape (4, 3, 5)
  • operations on NamedArrays always act only on the positional axes
    • untag, move those axes to the .positional_shape
    • operate, calling pz.nx.nmap(some_positional_op)(...args...)
    • retag, .tag to move the resulting axes from the .positional_shape back to the .named_shape
# kernel is assumed to be a pz.Parameter captured from the enclosing layer
# (e.g. self.weights in a Linear layer), whose value is a NamedArray with
# named axes "out_features" and "in_features"
def named_dot(x: pz.nx.NamedArray, features_axis: str) -> pz.nx.NamedArray:
	pos_x = x.untag(features_axis)                                  # named axis -> positional
	pos_kernel = kernel.value.untag("out_features", "in_features")  # parameter value -> positional
	pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)                  # vectorised over remaining named axes
	return pos_y.tag(features_axis)                                 # positional result -> named axis