jax.experimental.checkify provides runtime error checking that still works under jax.jit.

GPJax uses such checks to make sure that lengthscale parameters are non-negative, for example:

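```python
import jax.numpy as jnp
from jax.experimental import checkify


# The decorator functionalizes the check: calling this returns (error, None).
@checkify.checkify
def _check_is_non_negative(value):
    checkify.check(
        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
    )
```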
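As a quick sketch, the decorated checker can be called directly. It is repeated here so the snippet runs on its own, and the input arrays are made up for illustration. Because of @checkify.checkify, the call returns a pair (error, output) instead of raising:

```python
import jax.numpy as jnp
from jax.experimental import checkify


@checkify.checkify
def _check_is_non_negative(value):
    checkify.check(
        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
    )


# Illustrative inputs: one valid, one containing a negative entry.
err_ok, _ = _check_is_non_negative(jnp.array([0.5, 1.0]))
err_bad, _ = _check_is_non_negative(jnp.array([0.5, -1.0]))

print(err_ok.get())   # None: the check passed
print(err_bad.get())  # formatted failure message for the negative entry
```

error.get() returns None when every check passed and the formatted message otherwise; error.throw() raises that message as an exception instead.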

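This means that whenever a function containing such a check is jit-compiled, it must also be wrapped with jax.experimental.checkify.checkify. The wrapper functionalizes the check: rather than raising from inside compiled code (which jit cannot stage), the wrapped function returns an Error value alongside its normal output. Compilation therefore succeeds, and the error can be inspected or raised on the host afterwards.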
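To illustrate the failure the wrapper avoids, here is a minimal sketch using a hypothetical scale function (not part of GPJax) that calls checkify.check directly and is jit-compiled without checkify.checkify. JAX raises an error while tracing; the exact exception type and message depend on the JAX version, so the sketch catches Exception broadly:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def scale(lengthscale):
    # Hypothetical function with a bare (not functionalized) check.
    checkify.check(jnp.all(lengthscale >= 0), "lengthscale must be non-negative")
    return 2.0 * lengthscale


failure = None
try:
    # jit traces scale abstractly; the bare check cannot be staged.
    jax.jit(scale)(1.0)
except Exception as exc:  # exception type varies across JAX versions
    failure = exc

print(type(failure).__name__, failure)  # error raised while tracing
```

Note that the input 1.0 is valid: the failure comes from staging the check, not from the check's outcome.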

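```python
import jax
from jax.experimental import checkify

jit_compute_gram = jax.jit(checkify.checkify(compute_gram))  # compute_gram: defined elsewhere
error, value = jit_compute_gram(1.0)
error.throw()  # raises if a check inside compute_gram failed
```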
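An end-to-end sketch, assuming a hypothetical compute_gram that builds a squared-exponential (RBF) Gram matrix on five fixed inputs. This is not GPJax's implementation; only the checkify wiring mirrors the text above:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def compute_gram(lengthscale):
    # Hypothetical stand-in for a kernel's Gram-matrix computation.
    checkify.check(
        jnp.all(lengthscale >= 0),
        "lengthscale needs to be non-negative, got {ls}",
        ls=lengthscale,
    )
    x = jnp.linspace(0.0, 1.0, 5)
    sq_dist = (x[:, None] - x[None, :]) ** 2
    # Squared-exponential kernel with unit variance.
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


# checkify.checkify turns the check into an explicit Error output, so jit can stage it.
jit_compute_gram = jax.jit(checkify.checkify(compute_gram))

ok_error, gram = jit_compute_gram(1.0)
ok_error.throw()  # no-op: the check passed
print(gram.shape)

bad_error, _ = jit_compute_gram(-1.0)
print(bad_error.get())  # formatted message reporting the negative lengthscale
```

With a negative lengthscale the compiled function still returns a Gram matrix (the lengthscale is squared), so it is the returned Error, not the numerical output, that signals the invalid input; calling bad_error.throw() would raise it.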