The `pz.nx.NamedArray` class wraps an ordinary array and assigns each axis to either a position or a name (but not both). A NamedArray therefore has two shapes:
- `.positional_shape` is the sequence of dimension sizes for the positional axes in the array.
- `.named_shape` is a dictionary mapping axis names to their dimension sizes.
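To make the two shapes concrete, here is a toy NumPy sketch of the bookkeeping. `ToyNamedArray` is a hypothetical class, not penzai's implementation; it only mirrors the rule that every axis is either positional or named, so the two shapes together account for each axis exactly once.

```python
import numpy as np


class ToyNamedArray:
    """Hypothetical toy (not penzai's NamedArray): each axis of `data` is
    either positional (name None) or named, never both."""

    def __init__(self, data: np.ndarray, names: list):
        # One entry per axis of `data`; None marks a positional axis.
        if len(names) != data.ndim:
            raise ValueError("need exactly one name (or None) per axis")
        named = [n for n in names if n is not None]
        if len(named) != len(set(named)):
            raise ValueError("axis names must be unique")
        self.data = data
        self.names = list(names)

    @property
    def positional_shape(self) -> tuple:
        # Sizes of the unnamed axes, in their original order.
        return tuple(s for s, n in zip(self.data.shape, self.names) if n is None)

    @property
    def named_shape(self) -> dict:
        # Mapping from each axis name to that axis's size.
        return {n: s for s, n in zip(self.data.shape, self.names) if n is not None}


arr = ToyNamedArray(np.zeros((2, 3, 4)), [None, "batch", None])
print(arr.positional_shape)  # (2, 4)
print(arr.named_shape)       # {'batch': 3}
```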
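Tagging gives names to the positional axes; untagging moves named axes back into the positional shape, in the order given:

```python
import jax.numpy as jnp
from penzai import pz
my_named_array = pz.nx.wrap(jnp.ones([3, 4, 5])).tag("foo", "bar", "baz")
my_named_array.untag("bar", "foo", "baz").unwrap()  # a jax.Array of shape (4, 3, 5)
```

- Operations on NamedArrays always act only on the positional axes and are vectorized over the named axes. To apply an operation to named axes:
  - untag: call `.untag` to move those axes to the `.positional_shape`
  - operate: call `pz.nx.nmap(some_positional_op)(...args...)`
  - retag: call `.tag` to move the resulting axes from the `.positional_shape` back to the `.named_shape`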
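The untag, operate, retag steps can be mimicked in plain NumPy to see what nmap is doing. `toy_nmap` below is a hypothetical helper, not penzai's implementation: it treats the listed axis indices as named, hands `fn` one slice at a time that contains only the positional axes, and stacks the results.

```python
import numpy as np


def toy_nmap(fn, data: np.ndarray, named_axes: list) -> np.ndarray:
    """Hypothetical sketch of nmap for a single array.

    `named_axes` lists the axis indices of `data` that carry names. `fn` only
    sees the remaining (positional) axes and is applied once per index of the
    named axes. The result puts the named axes first, followed by the
    positional axes returned by `fn`.
    """
    pos_axes = [i for i in range(data.ndim) if i not in named_axes]
    # Move the named axes to the front so we can iterate over them.
    moved = np.transpose(data, list(named_axes) + pos_axes)
    named_sizes = moved.shape[: len(named_axes)]
    results = [np.asarray(fn(moved[idx])) for idx in np.ndindex(*named_sizes)]
    return np.stack(results).reshape(named_sizes + results[0].shape)


# Axis 0 plays the role of a named "batch" axis; axis 1 is a "features" axis
# that has been untagged, so it is positional.
x = np.arange(1.0, 7.0).reshape(2, 3)
# Operate: normalize each feature vector; fn sees a 1-D array of length 3.
y = toy_nmap(lambda v: v / v.sum(), x, named_axes=[0])
# Retagging would give axis 1 its "features" name back.
print(y.shape)  # (2, 3)
```

The named axis is never visible to `fn`, which is why a positional-only function can be reused unchanged across any number of named axes.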
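```python
import jax.numpy as jnp
from penzai import pz

def named_dot(x: pz.nx.NamedArray, kernel, features_axis: str) -> pz.nx.NamedArray:
    # kernel.value is assumed to be a NamedArray with "out_features" and "in_features" axes.
    pos_x = x.untag(features_axis)  # untag: the features axis becomes positional
    pos_kernel = kernel.value.untag("out_features", "in_features")
    pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)  # operate; vectorized over named axes
    return pos_y.tag(features_axis)  # retag the output axis with the same name
```

Note: for `pmap`, there is also an `axis_name` parameter.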
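As a concrete check of what `named_dot` computes, the NumPy sketch below plays out the same contraction with plain arrays. The sizes and the extra "batch" named axis are illustrative assumptions: with `x` carrying axes ("batch", features) and the kernel carrying ("out_features", "in_features"), nmap applies the dot product once per batch index, which matches an einsum over the shared input-feature axis.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D_in, D_out = 2, 3, 4  # illustrative sizes (assumptions)
x = rng.normal(size=(B, D_in))           # axes: ("batch", features)
kernel = rng.normal(size=(D_out, D_in))  # axes: ("out_features", "in_features")

# What nmap(jnp.dot)(pos_kernel, pos_x) does: one matrix-vector product
# per index of every named axis (here only "batch").
y_loop = np.stack([kernel @ x[b] for b in range(B)])

# The same contraction written as a single einsum.
y_einsum = np.einsum("oi,bi->bo", kernel, x)

assert np.allclose(y_loop, y_einsum)
print(y_loop.shape)  # (2, 4)
```

One consequence is visible in the shapes: after `tag(features_axis)`, the features axis has size `D_out` rather than `D_in`, so the input and output share an axis name but not necessarily a size.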