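For $f : \mathbb{R}^p \times \mathbb{R}^n \to \mathbb{R}^n$, the equation $f(a, z) = 0$ defines a system of nonlinear equations in $z \in \mathbb{R}^n$, parameterised by $a \in \mathbb{R}^p$.
Take a solution mapping $z^* : \mathbb{R}^p \to \mathbb{R}^n$ (e.g. the solution of a fixed point equation $z = g(a, z)$, posed as the root of $f(a, z) = z - g(a, z)$), which satisfies
$$f(a, z^*(a)) = 0.$$
Differentiate both sides with respect to $a$ and evaluate at the point $(a_0, z_0)$, where $z_0 = z^*(a_0)$:
$$\partial_a f(a_0, z_0) + \partial_z f(a_0, z_0)\, \partial z^*(a_0) = 0.$$
Then, provided $\partial_z f(a_0, z_0)$ is non-singular, rearrange for the Jacobian:
$$\partial z^*(a_0) = -\left[\partial_z f(a_0, z_0)\right]^{-1} \partial_a f(a_0, z_0).$$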
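As a quick numerical check of the rearranged formula, here is a minimal sketch for a hypothetical scalar case ($n = p = 1$) with $f(a, z) = z - \tanh(w z + a)$, where $w = 0.5$ is an arbitrary choice that makes plain fixed-point iteration converge. The implicit derivative uses only $\partial_a f$ and $\partial_z f$ evaluated at $z^*(a)$, and should agree with a central finite difference taken through the solver. All function names are illustrative.

```python
import math

# Hypothetical scalar problem: z = tanh(w*z + a), i.e. the root of
# f(a, z) = z - tanh(w*z + a). w = 0.5 is an arbitrary choice that makes
# z -> tanh(w*z + a) a contraction, so plain iteration converges.
w = 0.5

def solve(a, z=0.0, tol=1e-12, max_iter=1000):
    """Fixed-point iteration for z*(a); any root finder could be used instead."""
    for _ in range(max_iter):
        z_next = math.tanh(w * z + a)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    return z

def ift_derivative(a):
    """dz*/da = -[d_z f]^{-1} d_a f, using only partials of f at the solution."""
    z = solve(a)
    s = 1.0 - math.tanh(w * z + a) ** 2   # sech^2(w z + a)
    df_dz = 1.0 - w * s                   # partial_z f(a, z*)
    df_da = -s                            # partial_a f(a, z*)
    return -df_da / df_dz

def finite_difference(a, h=1e-6):
    """Central difference through the solver, for comparison only."""
    return (solve(a + h) - solve(a - h)) / (2 * h)

a0 = 0.3
print(ift_derivative(a0), finite_difference(a0))
```

Swapping `solve` for any other root finder (e.g. Newton's method) leaves `ift_derivative` unchanged, since it only needs the partials of $f$ at the solution.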
- The Jacobian of the solution mapping can be expressed purely in terms of Jacobians of $f$ at the solution point, regardless of how the solution was found.
- So automatic differentiation systems need not propagate derivatives through the solver's iterations; we can run the solver without tracking gradients and supply the formula above as a custom derivative rule:
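  - In PyTorch, run the solver inside `torch.no_grad()`, recompute `z` with one differentiable evaluation of the layer at the solution (so autograd supplies the $\partial_a$ term), and add the backward pass with `z.register_hook(lambda grad: torch.linalg.solve(J.transpose(1, 2), grad[:, :, None])[:, :, 0])`, where `J` is the batch of Jacobians $\partial_z f$ at the solution. The older `torch.solve(grad[:,:,None], J.transpose(1,2))[0][:,:,0]` is deprecated in favour of `torch.linalg.solve`, which takes the matrix first.
  - In JAX, use `jax.custom_jvp` and `jax.custom_vjp`.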
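A minimal sketch of the JAX route, assuming a hypothetical fixed-point layer $g(a, z) = \tanh(W z + a)$ with a small, arbitrarily chosen matrix `W`: `jax.custom_vjp` replaces reverse-mode differentiation through the solver with one linear solve against $(\partial_z f)^\top$ at the solution, followed by a VJP of $g$ with respect to $a$ (since $-\partial_a f = \partial_a g$ for $f = z - g$). The names `fixed_point` and `_iterate`, the iteration count, and the dense Jacobian are illustrative assumptions, not a library API.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical layer g(a, z) = tanh(W z + a). W has spectral norm < 1, so
# z -> g(a, z) is a contraction and plain iteration converges.
W = jnp.array([[0.3, -0.2], [0.1, 0.4]])

def g(a, z):
    return jnp.tanh(W @ z + a)

def _iterate(g, a, z0, num_iters=100):
    # Forward solver: naive fixed-point iteration z <- g(a, z).
    return jax.lax.fori_loop(0, num_iters, lambda _, z: g(a, z), z0)

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(g, a, z0):
    return _iterate(g, a, z0)

def fixed_point_fwd(g, a, z0):
    z_star = _iterate(g, a, z0)
    # Only the solution is saved as a residual, not the solver's iterates.
    return z_star, (a, z_star)

def fixed_point_bwd(g, res, z_bar):
    a, z_star = res
    # With f(a, z) = z - g(a, z): partial_z f = I - partial_z g at the solution.
    df_dz = jnp.eye(z_star.shape[0]) - jax.jacobian(g, argnums=1)(a, z_star)
    # Solve (partial_z f)^T u = z_bar, i.e. u^T = z_bar^T [partial_z f]^{-1}.
    u = jnp.linalg.solve(df_dz.T, z_bar)
    # a_bar = -u^T partial_a f = u^T partial_a g, computed as a VJP through g.
    _, vjp_a = jax.vjp(lambda a_: g(a_, z_star), a)
    (a_bar,) = vjp_a(u)
    # The initial guess does not affect z* here, so it gets a zero cotangent.
    return a_bar, jnp.zeros_like(z_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# Compare the implicit gradient with backprop through the unrolled loop.
def loss_implicit(a):
    return jnp.sum(fixed_point(g, a, jnp.zeros(2)) ** 2)

def loss_unrolled(a):
    return jnp.sum(_iterate(g, a, jnp.zeros(2)) ** 2)

a = jnp.array([0.5, -0.3])
print(jax.grad(loss_implicit)(a), jax.grad(loss_unrolled)(a))
```

The two gradients should agree here because the unrolled loop has converged, but the implicit version never stores or differentiates the iterates. Forming $\partial_z f$ densely and calling `jnp.linalg.solve` is only sensible for small $n$; for large $n$ the transposed system would typically be solved iteratively using vector-Jacobian products.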