Adapted from Ravin Kumar’s GenAI Guidebook.

linear regression in numpy

Linear regression (1-D case):

    y = x · coefficients + bias

where x and coefficients have shape (N_features,) and bias is a scalar.

Vectorising to operate on the entire observation matrix x (shape (N_obs, N_features)) at once, in (jax-)numpy we write:

  • y = (x @ coefficients.T) + bias
  • y = jnp.einsum('ij,kj->ki', coefficients, x) + bias using einsum notation (see the sketch below)
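
As a minimal sketch that the two expressions agree (the shapes and random data here are made up for illustration, not from the original):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

x = jax.random.normal(k1, (100, 3))           # (N_obs, N_features)
coefficients = jax.random.normal(k2, (1, 3))  # (N_outputs, N_features)
bias = 0.5

y_matmul = (x @ coefficients.T) + bias                      # (N_obs, N_outputs)
y_einsum = jnp.einsum('ij,kj->ki', coefficients, x) + bias  # (N_obs, N_outputs)

assert jnp.allclose(y_matmul, y_einsum)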

linear regression in flax

In neural network notation, we call the regression coefficients the “weights”, W, and the intercept the bias, b, giving us

    y = x Wᵀ + b

This is a dense neural network with one layer and no activation function.

import flax.linen as nn
 
class LinearRegression(nn.Module):
    def setup(self):
        self.dense = nn.Dense(features=1)
 
    # or use `@nn.compact` to do this inline
    # Define the forward pass
    def __call__(self, x):
        y_pred = self.dense(x)
        return y_pred
 
model = LinearRegression()
y_pred = model.apply(params, x)  # `params` are created with model.init (see the training section below)

Use optax in a training loop to tune params via gradient descent.

import jax
import optax
from flax.training import train_state  # dataclass to keep train state
 
key = jax.random.PRNGKey(0)
params = model.init(key, x_obs)
 
@jax.jit
def flax_l2_loss(params, x, y_true):
	y_pred = model.apply(params, x)
	total_loss = optax.l2_loss(y_pred, y_true).sum()
	return total_loss
 
optimizer = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
 
_loss = []
for epoch in range(10000):
	# Calculate the gradient
	loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_obs, y_obs_noisy)
	_loss.append(loss)
	# Update the model parameters
	state = state.apply_gradients(grads=grads)
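
After training, the fitted weights and bias can be read back out of the parameter tree and used for predictions. A minimal sketch, assuming Flax's default parameter layout for the `dense` sub-module defined above:

# Predict with the trained parameters
y_fit = state.apply_fn(state.params, x_obs)

# Fitted coefficients and intercept (key names assume Flax's default param tree)
weights = state.params['params']['dense']['kernel']  # shape (N_features, 1)
intercept = state.params['params']['dense']['bias']  # shape (1,)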

Linear regression cannot fit nonlinear relationships between inputs and outputs. We can introduce nonlinearity by adding a second layer and an activation function, as sketched below.
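
A minimal sketch of that change, assuming a hidden width of 16 and a ReLU activation (both are arbitrary choices, not prescribed by the original):

class NonlinearRegression(nn.Module):
    def setup(self):
        self.hidden = nn.Dense(features=16)  # hidden layer (width is an arbitrary choice)
        self.dense = nn.Dense(features=1)    # output layer

    def __call__(self, x):
        x = nn.relu(self.hidden(x))  # the activation introduces the nonlinearity
        return self.dense(x)

model = NonlinearRegression()
params = model.init(key, x_obs)
# The same optax training loop as above applies unchanged.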

Then just scale up and train for longer.