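Take a solution mapping function $z^\star(x)$, defined implicitly as the root of a residual function (written $f$ here):

$$f(x, z^\star(x)) = 0$$

Differentiate both sides with respect to $x$:

$$\frac{\partial f}{\partial x}(x, z^\star(x)) + \frac{\partial f}{\partial z}(x, z^\star(x)) \, \frac{\partial z^\star}{\partial x}(x) = 0$$

And then rearrange for the Jacobian:

$$\frac{\partial z^\star}{\partial x}(x) = -\left[\frac{\partial f}{\partial z}(x, z^\star(x))\right]^{-1} \frac{\partial f}{\partial x}(x, z^\star(x))$$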
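The identity can be checked numerically on a toy problem. The following is a minimal sketch, not the original example: the scalar residual `f(x, z) = z - tanh(W*z + x)`, the weight `W = 0.5`, and all helper names are illustrative assumptions. The derivative is computed only from partial derivatives of the residual at the solution, and can then be compared with a finite-difference estimate that re-runs the solver.

```python
import math

W = 0.5  # hypothetical weight; |W| < 1 makes the fixed-point iteration a contraction


def solve(x, iters=200):
    """Find z with f(x, z) = z - tanh(W*z + x) = 0 by plain fixed-point iteration."""
    z = 0.0
    for _ in range(iters):
        z = math.tanh(W * z + x)
    return z


def implicit_derivative(x):
    """dz*/dx = -(df/dz)^{-1} (df/dx), using only quantities at the solution point."""
    z = solve(x)
    s = 1.0 - math.tanh(W * z + x) ** 2  # derivative of tanh at the solution
    df_dx = -s            # partial derivative of the residual w.r.t. x
    df_dz = 1.0 - W * s   # partial derivative of the residual w.r.t. z
    return -df_dx / df_dz


def finite_difference(x, h=1e-6):
    """Central difference through the whole solver, for comparison only."""
    return (solve(x + h) - solve(x - h)) / (2 * h)
```

Evaluating `implicit_derivative(0.3)` and `finite_difference(0.3)` should give values that agree to several digits. Swapping `solve` for Newton's method or bisection would leave `implicit_derivative` unchanged, since it never looks at the solver's intermediate iterates.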
- The Jacobian of the solution mapping can be expressed just in terms of Jacobians of the residual function $f$ at the solution point (regardless of how we get to the solution).
- We can skip propagating derivatives through the solver using automatic differentiation systems:
  - In PyTorch, use `torch.no_grad()` when running the solver and add the backward pass with `z.register_hook(lambda grad: torch.linalg.solve(J.transpose(1,2), grad[:,:,None])[:,:,0])`, where `J` is the batched Jacobian of $f$ with respect to `z` at the solution.
  - In JAX, use `jax.custom_jvp` and `jax.custom_vjp`.
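The PyTorch recipe above can be sketched end to end as follows. This is an illustrative sketch under stated assumptions rather than a reference implementation: the residual `g = z - tanh(z W^T + x)`, the function name `tanh_implicit_layer`, the fixed iteration count, and the use of Newton's method as the solver are all choices made for the example. It also uses `torch.linalg.solve(A, B)`, assuming a PyTorch release where that function is available.

```python
import torch


def tanh_implicit_layer(x, W, iters=30):
    """Return z with z = tanh(z W^T + x); gradients come from the implicit function theorem.

    x: (batch, n) input, W: (n, n) weight. Both names are hypothetical.
    """
    n = x.shape[1]
    eye = torch.eye(n, dtype=x.dtype, device=x.device)[None]

    def residual_and_jacobian(z):
        a = torch.tanh(z @ W.T + x)
        g = z - a                                     # residual, shape (batch, n)
        J = eye - (1 - a ** 2)[:, :, None] * W[None]  # d g / d z, shape (batch, n, n)
        return g, J

    # Run the Newton solver outside of autograd: no graph is recorded here.
    with torch.no_grad():
        z = torch.tanh(x)
        for _ in range(iters):
            g, J = residual_and_jacobian(z)
            z = z - torch.linalg.solve(J, g[:, :, None])[:, :, 0]
        _, J = residual_and_jacobian(z)  # Jacobian at the (approximate) solution

    # Re-engage autograd with one application of the fixed-point map, then
    # replace the incoming gradient by J^{-T} grad in the backward pass.
    z = torch.tanh(z @ W.T + x)
    if z.requires_grad:
        z.register_hook(
            lambda grad: torch.linalg.solve(J.transpose(1, 2), grad[:, :, None])[:, :, 0]
        )
    return z
```

Calling `tanh_implicit_layer(x, W).sum().backward()` with `x.requires_grad=True` then fills `x.grad` without backpropagating through any of the Newton iterations. The `requires_grad` guard is there because `register_hook` cannot be attached to a tensor that does not require gradients.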