The pz.nx.NamedArray class wraps an ordinary JAX array and assigns each axis either a position or a name (but not both).
- a NamedArray has two shapes:
  - .positional_shape: the sequence of dimension sizes of the positional axes
  - .named_shape: a dictionary mapping each axis name to its dimension size
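- wrap an ordinary array, then name each of its positional axes:
  pz.nx.wrap(jnp.ones([3, 4, 5])).tag("foo", "bar", "baz")
- untag the named axes (they become positional in the order given), then unwrap to get a plain array back once every axis is positional:
  my_named_array.untag("bar", "foo", "baz").unwrap()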
- operations applied through pz.nx.nmap act only on the positional axes and are vectorized over every named axis, so the usual pattern is:
  - untag: call .untag to move the axes you want to operate on into the .positional_shape
  - operate: call pz.nx.nmap(some_positional_op)(...args...)
  - retag: call .tag to move the resulting axes from the .positional_shape back to the .named_shape
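- example: contract a named features axis against a kernel whose axes are named "out_features" and "in_features":
import jax.numpy as jnp
from penzai import pz

def named_dot(x: pz.nx.NamedArray, kernel: pz.nx.NamedArray, features_axis: str) -> pz.nx.NamedArray:
    # kernel: named axes "out_features", "in_features" (e.g. a parameter's .value)
    pos_x = x.untag(features_axis)
    pos_kernel = kernel.untag("out_features", "in_features")
    # jnp.dot contracts in_features; nmap vectorizes over all other named axes
    pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)
    return pos_y.tag(features_axis)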