fixed point in 1-D

Rewrite $f(x) = 0$ into the form $x = g(x)$ and then label the left hand side as $y = x$ and the right hand side as $y = g(x)$; the fixed points are where the two curves intersect.

e.g. $x^2 - x - 2 = 0$ becomes $x = \sqrt{x + 2}$.

  • converges when $|g'(x)| < 1$ in a neighbourhood of the fixed point (see the sketch below)
  • generally try to pick the rearrangement that reduces the degree of the polynomial
  • Newton's method is a special case of fixed point iteration where $g(x) = x - \frac{f(x)}{f'(x)}$
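
A minimal sketch of the 1-D iteration (the function name, tolerance, and example here are illustrative, not from any library):

import jax.numpy as jnp

def fixed_point_1d(g, x0, tol=1e-8, max_iter=100):
	# iterate x_{k+1} = g(x_k) until successive iterates agree to within tol
	x = x0
	for _ in range(max_iter):
		x_new = g(x)
		if abs(x_new - x) < tol:
			break
		x = x_new
	return x

# x^2 - x - 2 = 0 rewritten as x = sqrt(x + 2); |g'(2)| = 1/4 < 1,
# so the iterates converge to the root x = 2
root = fixed_point_1d(lambda x: jnp.sqrt(x + 2.0), x0=1.0)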

Oscar Veliz videos 1 and 2.

generalised fixed point iteration

Naive forward iteration, where we iterate $z_{k+1} = f(z_k)$ and stop once $z_{k+1}$ stays sufficiently close to $z_k$:

import jax.numpy as jnp

def fwd_solver(f, z_init):
	# iterate z <- f(z) until successive iterates are within tolerance
	z_prev, z = z_init, f(z_init)
	while jnp.linalg.norm(z_prev - z) > 1e-5:
		z_prev, z = z, f(z)
	return z
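
For example, solving $z = \tanh(Wz + x)$ by forward iteration (the shapes and scaling here are illustrative assumptions):

import jax
import jax.numpy as jnp

ndim = 10
W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
x = jax.random.normal(jax.random.PRNGKey(1), (ndim,))

# solve z* = tanh(W z* + x) by forward iteration
z_star = fwd_solver(lambda z: jnp.tanh(W @ z + x), z_init=jnp.zeros_like(x))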

gradients

If we use a fixed point iteration in a model where we need gradients, such as an implicit layer, naively applying automatic differentiation would require storing and differentiating through every intermediate step of the solver.

Instead, we can use the implicit function theorem to define the backward pass only at the solution point $z^*$.

For the fixed point iterator, the Jacobian becomes $\frac{\partial z^*}{\partial \theta} = \left(I - \frac{\partial f}{\partial z^*}\right)^{-1} \frac{\partial f}{\partial \theta}$ (as we now have $z = f(\theta, z)$ instead of $g(\theta, z) = 0$).
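
To see where this comes from, differentiate the fixed point condition with respect to $\theta$ and rearrange:

$$
\begin{aligned}
z^* &= f(\theta, z^*) \\
\frac{\partial z^*}{\partial \theta} &= \frac{\partial f}{\partial \theta} + \frac{\partial f}{\partial z^*} \frac{\partial z^*}{\partial \theta} \\
\frac{\partial z^*}{\partial \theta} &= \left(I - \frac{\partial f}{\partial z^*}\right)^{-1} \frac{\partial f}{\partial \theta}
\end{aligned}
$$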

Jacobian-vector product

We want to avoid explicitly inverting $\left(I - \frac{\partial f}{\partial z^*}\right)$ by using the adjoint method.

  1. Calculate $\frac{\partial f}{\partial \theta} v$ using jax.jvp.
  2. The final solution is now $u = \left(I - \frac{\partial f}{\partial z^*}\right)^{-1} \frac{\partial f}{\partial \theta} v$. So $u = \frac{\partial f}{\partial z^*} u + \frac{\partial f}{\partial \theta} v$, and we can use a fixed point iterator (or other solver) to solve the linear system for $u$ (see the sketch after this list).
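
A sketch of these two steps, reusing fwd_solver from above (fixed_point_jvp and its argument layout are assumptions, not library API):

import jax
import jax.numpy as jnp

def fixed_point_jvp(f, params, x, z_star, v):
	# step 1: (df/dparams) v at the solution point, via jax.jvp
	_, dfdp_v = jax.jvp(lambda p: f(p, x, z_star), (params,), (v,))
	# step 2: solve u = (df/dz) u + (df/dparams) v as another fixed point problem
	dfdz = lambda u: jax.jvp(lambda z: f(params, x, z), (z_star,), (u,))[1]
	return fwd_solver(lambda u: dfdz(u) + dfdp_v, z_init=jnp.zeros_like(z_star))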

vector-Jacobian product

  1. Calculate the adjoint $w^T = \bar{z}^T \left(I - \frac{\partial f}{\partial z^*}\right)^{-1}$, i.e. solve the fixed point $w = \left(\frac{\partial f}{\partial z^*}\right)^T w + \bar{z}$.
  2. Compute $\bar{\theta} = w^T \frac{\partial f}{\partial \theta}$ using jax.vjp.

The jax code for the VJP:

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
	# forward pass: run the solver to the fixed point z* = f(params, x, z*)
	z_star = solver(lambda z: f(params, x, z), z_init=jnp.zeros_like(x))
	return z_star

def fixed_point_layer_fwd(solver, f, params, x):
	z_star = fixed_point_layer(solver, f, params, x)
	# save what the backward pass needs: the inputs and the solution point
	return z_star, (params, x, z_star)

def fixed_point_layer_bwd(solver, f, res, z_star_bar):
	params, x, z_star = res
	# VJPs of f w.r.t. (params, x) and w.r.t. z, both evaluated at z*
	_, vjp_a = jax.vjp(lambda params, x: f(params, x, z_star), params, x)
	_, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)
	# solve the adjoint fixed point w = (df/dz)^T w + z_star_bar, then push w
	# through (df/dparams)^T and (df/dx)^T to get the input cotangents
	return vjp_a(solver(lambda u: vjp_z(u)[0] + z_star_bar, z_init=jnp.zeros_like(z_star)))

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)
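
Usage sketch, reusing fwd_solver and the illustrative W and x from above:

f = lambda params, x, z: jnp.tanh(jnp.dot(params, z) + x)

# forward pass through the implicit layer
z_star = fixed_point_layer(fwd_solver, f, W, x)

# gradients flow through the custom VJP rather than the solver's iterations
g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, x).sum())(W)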