tracer

JitTracers are abstract stand-ins for array objects. A tracer represents every possible value that shares a given set of essential attributes (shape, dtype, etc.), while a numpy array is one concrete member of that abstract class.
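As a rough illustration, the sketch below records what a jitted function actually receives. The names `double` and `seen` are made up for this example, and the printed tracer class name varies between jax versions, so it is printed without an expected value.

```python
import jax
import jax.numpy as jnp

seen = {}  # hypothetical scratch dict, only used to peek at the tracer

@jax.jit
def double(x):
    # Inside jit, x is a tracer: it knows its shape and dtype, but holds no values.
    seen["type"] = type(x).__name__
    seen["shape"] = x.shape
    seen["dtype"] = x.dtype
    return 2 * x

y = double(jnp.ones((3, 2), dtype=jnp.float32))

print(seen["type"])                 # some tracer class, name depends on the jax version
print(seen["shape"], seen["dtype"])
print(type(y).__name__)             # the result outside jit is a concrete array
```

Any other float32 array of shape (3, 2) is a member of the same abstract class, so it reuses the same trace.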

jaxpr

The jaxpr of a jax.jit compiled function is the sequence of primitive operations that defines the functional program. It can be inspected with jax.make_jaxpr(func)(*args).
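A minimal sketch of inspecting a jaxpr, assuming a toy function `affine` invented for this note. It prints the jaxpr and lists the primitives it contains.

```python
import jax
import jax.numpy as jnp

def affine(x):
    # Two primitive operations: a multiply followed by an add.
    return 2.0 * x + 1.0

# make_jaxpr returns a function; calling it with example arguments traces affine.
closed_jaxpr = jax.make_jaxpr(affine)(jnp.arange(3.0))
print(closed_jaxpr)

# Each equation in the jaxpr applies one primitive.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```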

compiling

jit compiling a function traces it with JitTracers in place of the real arguments. The aim when designing a jax program is to minimise how often a function needs to be recompiled. Once a function is compiled, it is fast for any input that fits the tracer it was compiled with (same shape and dtype), but every input that does not fit forces the function to be traced and compiled again, and the program becomes slow.
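One way to see when a retrace happens is to count how often the Python body runs, since it only runs during tracing. This is a sketch; `total` and `trace_count` are names made up for the example.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # hypothetical counter, incremented only while tracing

@jax.jit
def total(x):
    trace_count["n"] += 1  # Python side effect: skipped when the cached version runs
    return jnp.sum(x)

total(jnp.ones(4))    # first call: trace and compile
total(jnp.zeros(4))   # same shape and dtype: the compiled version is reused
after_same_shape = trace_count["n"]

total(jnp.ones(5))    # new shape: the tracer no longer fits, so retrace
after_new_shape = trace_count["n"]

print(after_same_shape, after_new_shape)  # 1 2
```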

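Avoid calling jax.jit within inner loops, especially on functions defined inside the loop: each new function object is traced and compiled afresh, so the program spends more of its time compiling.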
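The difference can be sketched by counting traces for the two patterns. The function names and the `traces` counter are made up for this example. The exact count for the in-loop version is not asserted here, only that it is larger.

```python
import jax
import jax.numpy as jnp

traces = {"inside": 0, "outside": 0}  # hypothetical counters
x = jnp.arange(4.0)

# Pattern to avoid: a fresh function is defined and jitted on every iteration,
# so jax cannot reuse the earlier compilation.
for _ in range(3):
    def step_inside(v):
        traces["inside"] += 1
        return v + 1.0
    x = jax.jit(step_inside)(x)

# Preferred: jit once, outside the loop, and call the compiled function inside it.
@jax.jit
def step_outside(v):
    traces["outside"] += 1
    return v + 1.0

for _ in range(3):
    x = step_outside(x)

print(traces)
```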

jax transformations are designed to work on side-effect-free (functionally pure) code, i.e. functions that do not read or modify any global state.
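A small sketch of what goes wrong with impure code: a value read from outside the function is baked in at trace time, so later changes are silently ignored. The `config` dict and `add_offset` are hypothetical names for this example.

```python
import jax

config = {"offset": 0.0}  # hypothetical global state

@jax.jit
def add_offset(x):
    # Impure: reads global state. The value seen during tracing is frozen into the jaxpr.
    return x + config["offset"]

first = add_offset(4.0)    # traced with offset = 0.0

config["offset"] = 10.0    # the compiled function never sees this change
second = add_offset(5.0)   # same input type, so the cached version runs

print(first, second)       # 4.0 5.0
```

Passing the offset as an explicit argument keeps the function pure and gives the expected result.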

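Use partial to mark arguments as static. A static argument enters the trace as a concrete Python value rather than an abstract tracer, so it can drive ordinary Python control flow, but the function is recompiled for every new value of that argument. Only use static arguments for values that rarely change, otherwise the function is recompiled on almost every call.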

from functools import partial

import jax

# n is static, so the Python while loop runs at trace time with a concrete n.
@partial(jax.jit, static_argnames=["n"])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))  # 30
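To make the cost of static arguments visible, the sketch below logs each value of n that triggers a trace. `add_n` and `trace_log` are names invented for this illustration.

```python
from functools import partial

import jax

trace_log = []  # hypothetical log of static values seen while tracing

@partial(jax.jit, static_argnames=["n"])
def add_n(x, n):
    trace_log.append(n)  # n is a plain Python int here, not a tracer
    return x + n

a = add_n(1.0, n=2)  # first value of n: trace and compile
b = add_n(5.0, n=2)  # same static value: cached
c = add_n(1.0, n=3)  # new static value: trace and compile again

print(trace_log)  # [2, 3]
```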

Python control flow is executed while the function is traced, so jax unrolls loops statically into the jaxpr, and a branch or loop condition that depends on a traced value raises an error. To avoid this static unrolling, use the control flow primitives jax.lax.cond, jax.lax.scan, jax.lax.while_loop, jax.lax.fori_loop, etc.
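A short sketch of two of these primitives, using made-up function names. Here the loop bound and the branch predicate are both traced values, which plain Python `while` and `if` could not handle under jit.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_one_n_times(x, n):
    # n is traced; fori_loop keeps the loop as one primitive instead of unrolling it.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + 1.0, x)

@jax.jit
def double_if_positive(x):
    # Both branches must return the same shape and dtype.
    return jax.lax.cond(x > 0, lambda v: 2.0 * v, lambda v: v, x)

looped = add_one_n_times(jnp.float32(10.0), 20)
doubled = double_if_positive(jnp.float32(3.0))
unchanged = double_if_positive(jnp.float32(-3.0))

print(looped, doubled, unchanged)  # 30.0 6.0 -3.0
```

Because n is not static here, calling add_one_n_times with a different n reuses the same compiled function.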