```python
import jax
import jax.numpy as jnp

pytree = [
    [1, 'x', object()],  # object() has no registered pytree type, so it is a leaf
    {'a': 1},
    {
        'b': 2,
        'c': (3, 4),
        'd': None,  # None is a node with no children, so it adds no leaf
    },
    jnp.array([1, 2, 3]),
]

leaves, treedef = jax.tree.flatten(pytree)
print(f"leaves = {leaves}")
print(f"treedef = {treedef}")
```

```text
leaves = [1, 'x', <object object at 0x1196ab980>, 1, 2, 3, 4, Array([1, 2, 3], dtype=int32)]
treedef = PyTreeDef([[*, *, *], {'a': *}, {'b': *, 'c': (*, *), 'd': None}, *])
```

`None` is not a leaf because it is defined as a pytree with no children.

- Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree.
- This can be changed by registering a custom pytree node type, e.g. with `jax.tree_util.register_pytree_node`.
- Dataclasses are not automatically pytrees, but they can be registered as one with `jax.tree_util.register_dataclass`.
- JAX arrays, which are leaves, are immutable: an update such as `x.at[i].set(v)` returns a new array.
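A short sketch of the registry rules above, assuming a recent JAX release that provides `jax.tree` and `jax.tree_util.register_dataclass`. `Point`, `flatten_point`, `unflatten_point` and `LinearParams` are hypothetical example names, not part of any library.

```python
import dataclasses

import jax
import jax.numpy as jnp

# 1. None is a pytree node with no children, so it contributes no leaves.
print(jax.tree.leaves({"a": 1, "b": None}))  # [1]


# 2. A plain class is not in the registry, so each instance is a single leaf...
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


print(len(jax.tree.leaves(Point(1.0, 2.0))))  # 1


# ...until it is registered as a custom pytree node.
def flatten_point(p):
    # Returns (children, aux_data); the children become the leaves.
    return (p.x, p.y), None


def unflatten_point(aux_data, children):
    return Point(*children)


jax.tree_util.register_pytree_node(Point, flatten_point, unflatten_point)
print(jax.tree.leaves(Point(1.0, 2.0)))  # [1.0, 2.0]


# 3. A registered dataclass: data fields become leaves,
#    meta fields are kept in the treedef as static metadata.
@dataclasses.dataclass
class LinearParams:
    w: jax.Array
    b: jax.Array
    name: str


jax.tree_util.register_dataclass(
    LinearParams, data_fields=["w", "b"], meta_fields=["name"]
)
p = LinearParams(w=jnp.ones(2), b=jnp.zeros(2), name="layer0")
print(len(jax.tree.leaves(p)))  # 2

# 4. Arrays are immutable: .at[...].set(...) returns an updated copy.
x = jnp.array([1, 2, 3])
y = x.at[0].set(10)  # x itself still holds [1, 2, 3]
```

Registering a type is global to the process, so in practice it is done once, at module import time.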
`NamedTuple` is often used to store params, since JAX treats named tuples as pytree nodes without any registration.
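A minimal sketch of NamedTuple params; `Params`, `loss` and the array shapes are hypothetical. Because the named tuple is a pytree, `jax.grad` returns gradients with the same `Params` structure.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    weight: jax.Array
    bias: jax.Array


params = Params(weight=jnp.ones((3, 2)), bias=jnp.zeros(2))

# Flattening yields one leaf per field; unflattening rebuilds a Params.
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)


def loss(params, x):
    return jnp.sum(x @ params.weight + params.bias)


# Gradients come back as a Params with the same field names.
grads = jax.grad(loss)(params, jnp.ones((4, 3)))
print(type(grads).__name__)  # Params
```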
`jax.tree.map(f: Callable[..., Any], tree: Any)` maps a function over a tree's leaves and returns a new pytree with the same structure.
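A small sketch, assuming hypothetical `params`/`grads` dictionaries. Besides the one-tree form, `jax.tree.map` also accepts extra trees with the same structure and passes `f` the matching leaves from each, which is a common way to write an optimizer step such as plain SGD.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array(1.0)}

# One-tree form: f receives each leaf.
doubled = jax.tree.map(lambda x: 2 * x, params)

# Multi-tree form: f receives corresponding leaves of params and grads.
learning_rate = 0.1
new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
```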
equinox
Subclassing `equinox.Module` makes a class both a dataclass and a pytree. tinygp uses `equinox.Module` to simplify its dataclasses, so we don't need to call `jax.tree_util.register_pytree_node` ourselves.
- Class-like behaviour can be obtained without extra abstractions, simply by exploiting pytrees.
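Equinox adds extra tree methods:

- `eqx.tree_at` returns a copy of a pytree with a particular leaf or leaves replaced; the original pytree is not mutated.
- `eqx.tree_serialise_leaves` saves the leaves of a pytree to a file (`eqx.tree_deserialise_leaves` loads them back).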
Equinox supports arbitrary Python objects as leaves (e.g. an activation function). Filtered transformations (`@eqx.filter_jit`, `@eqx.filter_grad`, etc.) are needed to split such a pytree into:

- `params`, the leaves that are arrays; and
- `static`, everything else, which is held fixed rather than traced or differentiated.
`eqx.is_array` is a boolean filter function that decides whether each leaf goes into `params` or into `static`: it returns `True` for JAX and NumPy arrays and `False` for everything else.