jax.vmap is a functional transform for vectorised mapping. The call vmap(func)(batched_array) is equivalent to the pure-Python loop:

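```python
import numpy as np

def func(x):
    ...  # compute result from a single (unbatched) example x
    return result

# Call func on each slice along axis 0, then stack the outputs
np.stack([func(array) for array in batched_array])
```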
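As a quick check of this equivalence, the sketch below (assuming JAX and NumPy are installed) fills in a hypothetical func, the squared norm of a single vector, and compares the explicit loop against jax.vmap on a small illustrative batch. Both give the same array, [5., 50., 149., 302.].

```python
import jax
import jax.numpy as jnp
import numpy as np

def func(x):
    # Hypothetical per-example function: squared norm of one vector
    return jnp.sum(x ** 2)

# Illustrative batch: 4 vectors of length 3, batch dimension on axis 0
batched_array = jnp.arange(12.0).reshape(4, 3)

# Pure-Python loop: call func on each row, then stack the results
looped = np.stack([func(array) for array in batched_array])

# vmap: the same computation, vectorised over axis 0
vmapped = jax.vmap(func)(batched_array)
```

The loop calls func once per example, whereas vmap traces func once and rewrites it into a single batched computation.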

By default, vmap maps over the leading (batch) dimension of its input (in_axes=0); the in_axes argument selects a different axis to map over. For example, a function sum_vector that sums a single vector can be applied to each row of a matrix with rowsum_func = vmap(sum_vector, in_axes=0), or to each column with colsum_func = vmap(sum_vector, in_axes=1).
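A minimal sketch of the row/column example, assuming sum_vector simply calls jnp.sum on a 1-D vector (its body is not given above) and using an illustrative 2×3 matrix. Mapping over axis 0 hands each row to sum_vector; mapping over axis 1 hands each column.

```python
import jax
import jax.numpy as jnp

def sum_vector(v):
    # Assumed body: sums a single 1-D vector and knows nothing about batching
    return jnp.sum(v)

rowsum_func = jax.vmap(sum_vector, in_axes=0)  # one call per row (axis 0)
colsum_func = jax.vmap(sum_vector, in_axes=1)  # one call per column (axis 1)

matrix = jnp.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])

row_sums = rowsum_func(matrix)  # [6., 15.]
col_sums = colsum_func(matrix)  # [5., 7., 9.]
```

In both cases the mapped results are stacked along axis 0 of the output, so row_sums has shape (2,) and col_sums has shape (3,).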

Further examples include converting the rows of a matrix into a stack of probability vectors, and parallelising MCMC inference when sampling multiple chains.
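For the first of these, one plausible approach (an assumption; the original example's exact method is not shown here) is to write a softmax for a single vector and vmap it over the rows. The helper name to_probabilities and the scores matrix are hypothetical.

```python
import jax
import jax.numpy as jnp

def to_probabilities(logits):
    # Hypothetical helper: numerically stable softmax of one 1-D vector.
    # Subtracting the max avoids overflow in exp without changing the result.
    shifted = logits - jnp.max(logits)
    exps = jnp.exp(shifted)
    return exps / jnp.sum(exps)

scores = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0],
                    [5.0, -1.0, 2.0]])

# Map over rows (default in_axes=0): each row becomes a probability vector
probs = jax.vmap(to_probabilities)(scores)
```

to_probabilities is written for one vector only; vmap supplies the batching, and the result matches jax.nn.softmax(scores, axis=-1).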

JAX transformations, vmap included, work transparently with either arrays or pytrees (nested lists, tuples and dicts) of arrays.
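A minimal sketch of vmap over a pytree, assuming a batch stored as a dict whose leaves all carry the batch dimension on axis 0; the function summarise and its field names are hypothetical. The default in_axes=0 applies to every leaf, and a pytree output comes back with the same structure, each leaf batched.

```python
import jax
import jax.numpy as jnp

def summarise(example):
    # example is one unbatched pytree: {"features": shape (3,), "weight": shape ()}
    total = jnp.sum(example["features"]) * example["weight"]
    return {"total": total, "max": jnp.max(example["features"])}

# A batch of 4 examples as a dict of arrays, batch dimension on axis 0 of each leaf
batch = {
    "features": jnp.arange(12.0).reshape(4, 3),
    "weight": jnp.array([1.0, 0.5, 2.0, 0.0]),
}

# Maps over axis 0 of every leaf; returns a dict with each leaf of shape (4,)
out = jax.vmap(summarise)(batch)
```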

The developer documentation explains how vmap is implemented.