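jax.vmap is a functional transform for vectorised mapping. The operation vmap(func)(batched_array) is equivalent to a pure-Python loop that applies func to each slice of batched_array along the batch dimension and stacks the results, except that it runs as a single vectorised computation.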
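A minimal sketch of that equivalence follows. The per-example function func and the input batched_array are hypothetical placeholders chosen for illustration. Both versions produce the same array of shape (4,).

```python
import jax
import jax.numpy as jnp

def func(x):
    # Hypothetical per-example function: written for a single 1-D vector
    return jnp.dot(x, x) + jnp.sin(x).sum()

batched_array = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors of length 3

# Pure-Python loop: iterate over axis 0, apply func, then stack the results
looped = jnp.stack([func(x) for x in batched_array])

# Vectorised equivalent: vmap maps func over the leading axis
vectorised = jax.vmap(func)(batched_array)
```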
By default, vmap vectorises over the leading (batch) dimension (in_axes=0). The in_axes argument can be used to change the dimension of vectorisation. For example, a summing function sum_vector can act over the rows of a matrix, rowsum_func = vmap(sum_vector, in_axes=0), or over its columns, colsum_func = vmap(sum_vector, in_axes=1).
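The row and column sums can be sketched as follows, assuming sum_vector simply sums a single 1-D vector (its body is not given in the text). With in_axes=0 each mapped slice is a row; with in_axes=1 each slice is a column.

```python
import jax
import jax.numpy as jnp

def sum_vector(v):
    # Sums one 1-D vector; the function itself knows nothing about batching
    return jnp.sum(v)

rowsum_func = jax.vmap(sum_vector, in_axes=0)  # map over axis 0: one call per row
colsum_func = jax.vmap(sum_vector, in_axes=1)  # map over axis 1: one call per column

m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

row_sums = rowsum_func(m)  # -> [ 6., 15.]
col_sums = colsum_func(m)  # -> [ 5.,  7.,  9.]
```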
Further examples include converting the rows of a matrix into a stack of probability vectors and parallelising MCMC inference by sampling multiple chains at once.
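Both examples can be sketched with vmap, under stated assumptions. The softmax helper to_probs, the standard-normal target log_density, and the random-walk Metropolis sampler run_chain (with its n_steps and step_size defaults) are hypothetical choices made for illustration, not the original author's code. In each case the function is written for a single vector or a single chain, and vmap supplies the batching.

```python
import jax
import jax.numpy as jnp

# --- Rows of a matrix -> stack of probability vectors ---
def to_probs(row):
    # Numerically stable softmax of a single 1-D vector
    z = row - jnp.max(row)
    e = jnp.exp(z)
    return e / jnp.sum(e)

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
probs = jax.vmap(to_probs)(logits)  # shape (2, 3); each row sums to 1

# --- Several MCMC chains in parallel ---
def log_density(x):
    # Hypothetical target: unnormalised standard normal
    return -0.5 * x ** 2

def run_chain(key, x0, n_steps=1000, step_size=0.5):
    # One random-walk Metropolis chain, written for a single scalar state
    def step(x, step_key):
        k_prop, k_acc = jax.random.split(step_key)
        proposal = x + step_size * jax.random.normal(k_prop)
        log_ratio = log_density(proposal) - log_density(x)
        accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (new carry, recorded sample)

    step_keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(step, x0, step_keys)
    return samples

n_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
x0s = jnp.linspace(-2.0, 2.0, n_chains)  # a different start for each chain
# Map over (key, initial state) pairs: one independent chain per pair
samples = jax.vmap(run_chain)(chain_keys, x0s)  # shape (4, 1000)
```

Giving each chain its own PRNG key keeps the chains independent; vmap then runs them together instead of in a Python loop.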
JAX transformations work transparently with either arrays or pytrees of arrays (nested containers such as lists, tuples and dicts whose leaves are arrays).
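Here is a short sketch of vmap with pytrees. The linear model predict, its params dict and the stats function are hypothetical. Passing in_axes=(None, 0) broadcasts the whole params pytree unchanged while mapping over the rows of xs. A function that returns a dict gets back a dict whose leaves carry the batch dimension.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Hypothetical linear model; params is a pytree (a dict of arrays)
    return jnp.dot(params["w"], x) + params["b"]

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])

# None: do not map over params (shared by every example); 0: map over rows of xs
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(params, xs)  # -> [1.5, 2.5, 3.5]

def stats(v):
    # Returns a pytree; vmap adds the batch dimension to every leaf
    return {"mean": jnp.mean(v), "max": jnp.max(v)}

out = jax.vmap(stats)(xs)  # {"mean": [0.5, 0.5, 1.0], "max": [1., 1., 1.]}
```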
The developer documentation explains how vmap is implemented.