Adapted from Ravin Kumar’s GenAI Guidebook.
linear regression in numpy
Linear regression (1-D case):

$$y = \beta \cdot x + c$$

where x and β have shape (N_features,).
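For a single observation this is just a dot product plus an intercept. A minimal sketch (the feature values, coefficients, and intercept below are made-up numbers for illustration):

```python
import jax.numpy as jnp

x = jnp.array([0.5, -1.0, 2.0])     # one observation, shape (N_features,)
beta = jnp.array([1.0, 0.2, -0.3])  # coefficients, shape (N_features,)
c = 0.1                             # intercept
y = jnp.dot(beta, x) + c            # scalar prediction
```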
Vectorising to operate on the entire y at once, in (jax-)numpy we write:
```python
y = (x @ coefficients.T) + bias
```

or, equivalently, using einsum notation:

```python
y = jnp.einsum('ij,kj->ki', coefficients, x) + bias
```
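As a concrete sketch, both expressions compute the same result (the random data, shapes, and names such as coefficients and bias are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))           # (N_obs, N_features)
coefficients = jnp.array([[1.0, -2.0, 0.5]])  # (N_targets, N_features)
bias = jnp.array([0.1])                       # (N_targets,)

y_matmul = (x @ coefficients.T) + bias                      # (N_obs, N_targets)
y_einsum = jnp.einsum('ij,kj->ki', coefficients, x) + bias  # same result

assert jnp.allclose(y_matmul, y_einsum)
```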
linear regression in flax
In neural network notation, we call the regression coefficients the “weights”, w, and the intercept the bias, b, giving us
$$F(x) = w^T x + b$$
This is a dense neural network with one layer and no activation function.
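A minimal sketch of that single-layer model in flax.linen (the module name LinearRegression and the shapes are my own illustrative choices):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LinearRegression(nn.Module):
    @nn.compact
    def __call__(self, x):
        # one Dense layer with no activation is exactly F(x) = w^T x + b
        return nn.Dense(features=1)(x)

model = LinearRegression()
x = jnp.ones((32, 3))                          # (N_obs, N_features)
params = model.init(jax.random.PRNGKey(0), x)  # initialise w and b
y = model.apply(params, x)                     # predictions, shape (32, 1)
```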
Use optax in a training loop to tune params via gradient descent.
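A sketch of such a loop; the synthetic data, the SGD optimiser, the learning rate, and the step count are all assumptions chosen for illustration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class LinearRegression(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=1)(x)

model = LinearRegression()

# synthetic 1-D data: y = 3x + 0.5 plus a little noise (purely illustrative)
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 100).reshape(-1, 1)
y_true = 3.0 * x + 0.5 + 0.05 * jax.random.normal(key, x.shape)

params = model.init(key, x)
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

def loss_fn(params):
    pred = model.apply(params, x)
    return jnp.mean((pred - y_true) ** 2)   # mean squared error

@jax.jit
def train_step(params, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for step in range(500):
    params, opt_state, loss = train_step(params, opt_state)
```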
Linear regression cannot deal with nonlinear data. We can introduce nonlinearity by adding a second layer and an activation function.
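For example, a small two-layer model with a ReLU activation might look like this (the hidden width of 16 is an arbitrary illustrative choice):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SmallMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=16)(x)   # second layer adds capacity
        x = nn.relu(x)                 # activation introduces nonlinearity
        return nn.Dense(features=1)(x)

model = SmallMLP()
x = jnp.ones((32, 1))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)             # predictions, shape (32, 1)
```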
Then just scale up and train for longer.