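The equation $f(x, z) = 0$ defines a system of nonlinear equations on $z$, parameterised by $x$.

Take a solution mapping function $z^*(x)$ (e.g. solving the fixed point equation $z = g(x, z)$ with root function $f(x, z) = z - g(x, z)$), which satisfies

$$f(x, z^*(x)) = 0.$$

Differentiate both sides with respect to $x$ and evaluate at the point $(x_0, z_0)$, where $z_0 = z^*(x_0)$:

$$\partial_x f(x_0, z_0) + \partial_z f(x_0, z_0)\, \partial_x z^*(x_0) = 0.$$

Then rearrange for the Jacobian of the solution mapping, assuming $\partial_z f(x_0, z_0)$ is invertible:

$$\partial_x z^*(x_0) = -\left[\partial_z f(x_0, z_0)\right]^{-1} \partial_x f(x_0, z_0).$$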
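
The rearranged Jacobian can be sanity-checked numerically. The sketch below is a minimal scalar illustration, not taken from the source: it assumes a hypothetical fixed-point map `g(x, z) = tanh(w*z + x)` with an arbitrary weight `w = 0.5`, solves for the root of `f(x, z) = z - g(x, z)` with Newton's method, and compares the implicit derivative with a central finite difference of the solver output.

```python
import math

W = 0.5  # hypothetical fixed weight, chosen only for illustration


def g(x, z):
    """Assumed fixed-point map: a solution satisfies z = g(x, z)."""
    return math.tanh(W * z + x)


def f(x, z):
    """Root function: f(x, z) = 0 exactly when z is a fixed point of g."""
    return z - g(x, z)


def df_dz(x, z):
    """Partial derivative of f with respect to z."""
    return 1.0 - W * (1.0 - math.tanh(W * z + x) ** 2)


def df_dx(x, z):
    """Partial derivative of f with respect to the parameter x."""
    return -(1.0 - math.tanh(W * z + x) ** 2)


def solve(x, tol=1e-12, max_iter=50):
    """Newton's method on f(x, .) = 0, started from z = 0."""
    z = 0.0
    for _ in range(max_iter):
        step = f(x, z) / df_dz(x, z)
        z -= step
        if abs(step) < tol:
            break
    return z


x0 = 0.3
z0 = solve(x0)

# Implicit derivative: uses only partial derivatives of f at (x0, z0).
implicit_grad = -df_dx(x0, z0) / df_dz(x0, z0)

# Reference: central finite difference through the whole solver.
eps = 1e-6
fd_grad = (solve(x0 + eps) - solve(x0 - eps)) / (2 * eps)

print(implicit_grad, fd_grad)
```

The implicit derivative never looks at the Newton iterations; any other solver that reaches the same `z0` would give the same value.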

  • The Jacobian of the solution mapping can be expressed purely in terms of Jacobians of $f$ at the solution point, regardless of how the solution was reached.
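  • Automatic differentiation systems therefore do not need to propagate derivatives through the solver's iterations:
    • In PyTorch, run the solver under torch.no_grad(), re-attach the solution to the graph with one differentiable evaluation of the fixed-point map, and add the backward pass with z.register_hook(lambda grad: torch.linalg.solve(J.transpose(1, 2), grad[:, :, None])[:, :, 0]), where J is the batched Jacobian of the system with respect to z at the solution.
    • In JAX, use jax.custom_jvp or jax.custom_vjp to attach the implicit derivative rule to the solver.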
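
The PyTorch recipe above can be sketched as follows. This is a minimal illustration under stated assumptions, not the source's implementation: the fixed-point map `g(x, z, W) = tanh(z W^T + x)`, the function name `implicit_layer`, and the small random `W` are all hypothetical choices made so that the example is self-contained.

```python
import torch


def g(x, z, W):
    """Assumed fixed-point map; x and z have shape (batch, n), W has shape (n, n)."""
    return torch.tanh(z @ W.T + x)


def jacobian_z(x, z, W):
    """Batched Jacobian of f(x, z) = z - g(x, z, W) with respect to z, shape (batch, n, n)."""
    t = torch.tanh(z @ W.T + x)
    eye = torch.eye(z.shape[1], dtype=z.dtype, device=z.device)
    return eye[None] - (1 - t ** 2)[:, :, None] * W[None]


def implicit_layer(x, W, tol=1e-10, max_iter=50):
    # 1. Run Newton's method without recording anything on the autograd tape.
    with torch.no_grad():
        z = torch.zeros_like(x)
        for _ in range(max_iter):
            residual = z - g(x, z, W)
            if residual.norm() < tol:
                break
            J = jacobian_z(x, z, W)
            z = z - torch.linalg.solve(J, residual[:, :, None])[:, :, 0]
        J = jacobian_z(x, z, W)  # Jacobian at the solution, used in the backward pass

    # 2. Re-attach the solution to the graph with one differentiable evaluation.
    #    The z passed in carries no history, so autograd sees only d g / d x (and d g / d W).
    z = g(x, z, W)

    # 3. Replace the incoming gradient by J^{-T} grad, as in the implicit formula.
    if z.requires_grad:
        z.register_hook(
            lambda grad: torch.linalg.solve(J.transpose(1, 2), grad[:, :, None])[:, :, 0]
        )
    return z


torch.manual_seed(0)
W = 0.1 * torch.randn(3, 3, dtype=torch.float64)  # small weights keep g a contraction
x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)

z = implicit_layer(x, W)
z.sum().backward()
print(x.grad)
```

Memory use in the backward pass is independent of the number of solver iterations, since only the final evaluation of `g` is recorded. The cost is one extra batched linear solve per backward call.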