Adapted from Ravin Kumar’s GenAI Guidebook.
linear regression in numpy
Linear regression (1-D case) is $y = \mathbf{x} \cdot \boldsymbol{\beta} + \beta_0$, where $\mathbf{x}$ and $\boldsymbol{\beta}$ both have shape (N_features,).
Vectorising to operate on the entire dataset $X$, of shape (N_obs, N_features), at once, in (jax-)numpy we write:
import jax.numpy as jnp

y = (x @ coefficients.T) + bias
y = jnp.einsum('ij,kj->ki', coefficients, x) + bias

where the second line expresses the same computation using einsum notation.
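As a quick sanity check that the two forms agree, here is a minimal sketch with made-up shapes (5 observations, 3 features, a single target; none of these numbers come from the guidebook):

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (5, 3))   # (N_obs, N_features)
coefficients = jnp.array([[2.0, -1.0, 0.5]])   # (N_targets, N_features)
bias = 1.0

y_matmul = (x @ coefficients.T) + bias                       # shape (5, 1)
y_einsum = jnp.einsum('ij,kj->ki', coefficients, x) + bias   # same result
assert jnp.allclose(y_matmul, y_einsum)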
linear regression in flax
In neural network notation, we call the regression coefficients the “weights”, $W$, and the intercept the bias, $b$, giving us $y = x W^{\top} + b$.
This is a dense neural network with one layer and no activation function.
import flax.linen as nn

class LinearRegression(nn.Module):
    def setup(self):
        self.dense = nn.Dense(features=1)
        # or use `@nn.compact` to do this inline

    # Define the forward pass
    def __call__(self, x):
        y_pred = self.dense(x)
        return y_pred
model = LinearRegression()
# `params` is the pytree of weights created by `model.init` (shown below)
y_pred = model.apply(params, x)
Use optax in a training loop to tune params via gradient descent.
import jax
import optax
from flax.training import train_state  # dataclass to keep train state

# x_obs, y_obs_noisy: the observed (noisy) training data defined earlier
key = jax.random.PRNGKey(0)
params = model.init(key, x_obs)

@jax.jit
def flax_l2_loss(params, x, y_true):
    y_pred = model.apply(params, x)
    total_loss = optax.l2_loss(y_pred, y_true).sum()
    return total_loss

optimizer = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

_loss = []
for epoch in range(10000):
    # Calculate the loss and its gradient with respect to the parameters
    loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_obs, y_obs_noisy)
    _loss.append(loss)
    # Update the model parameters
    state = state.apply_gradients(grads=grads)
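After training, the fitted weights and bias live in state.params. As a rough sketch of how to pull them out (the nested keys follow Flax's convention of naming parameters after the submodule attribute, here dense):

learned_kernel = state.params['params']['dense']['kernel']  # shape (N_features, 1)
learned_bias = state.params['params']['dense']['bias']      # shape (1,)
print(learned_kernel.ravel(), learned_bias)

# Predictions with the trained parameters
y_fit = state.apply_fn(state.params, x_obs)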
Linear regression cannot capture nonlinear relationships in the data. We can introduce nonlinearity by adding a second layer and an activation function, as sketched below. Then just scale up and train for longer.
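A minimal sketch of that idea, assuming a hidden width of 8 and a ReLU activation (both arbitrary choices, and the name TinyMLP is purely illustrative):

import flax.linen as nn

class TinyMLP(nn.Module):
    hidden_features: int = 8  # hidden width chosen arbitrarily for illustration

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(features=self.hidden_features)(x)
        h = nn.relu(h)  # the activation is what introduces the nonlinearity
        y_pred = nn.Dense(features=1)(h)
        return y_pred

model = TinyMLP()  # drop-in replacement for LinearRegression in the training loop above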