"""Patterns for generating PRNG keys in JAX.

JAX random functions are pure: calling a sampler twice with the same key
returns the same numbers. Every draw therefore needs its own key, derived
by splitting. A key that has been used for sampling should not be split
or reused afterwards.
"""

import itertools

import jax


def key_gen(seed, N=None):
    """Yield fresh PRNG keys derived from ``seed``.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.
        N: Number of keys to yield. An integer bounds the generator, so a
            plain ``for`` loop over it terminates. ``None`` (the default)
            yields forever; pull keys with ``next()`` in that case.

    Yields:
        A new subkey on every step. The same sequence is produced by the
        loop ``key, subkey = jax.random.split(key)`` starting from
        ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    steps = itertools.count() if N is None else range(N)
    for _ in steps:
        # Split first and hand out the subkey. The carried ``key`` is only
        # ever split, never given to a sampler, so no key is used twice.
        key, subkey = jax.random.split(key)
        yield subkey


def f(key):
    """Draw one standard-normal scalar from ``key``."""
    return jax.random.normal(key)


def main(seed=10, N=3):
    """Demonstrate sequential and vectorised key usage."""
    # 1. Sequential draws, the pattern shown in the JAX documentation:
    #    split once per draw, carry the first half, consume the second.
    key = jax.random.PRNGKey(seed)
    for _ in range(N):
        key, subkey = jax.random.split(key)
        print(jax.random.normal(subkey))

    # 2. The same idea wrapped in a generator, with a limit so that the
    #    ``for`` loop terminates ...
    for key in key_gen(seed, N=N):
        print(jax.random.normal(key))

    # 3. ... or without a limit, pulling keys on demand with ``next()``.
    keys = key_gen(seed)
    print(jax.random.normal(next(keys)))
    print(jax.random.normal(next(keys)))

    # 4. Non-sequential (independent) draws: split all keys up front and
    #    vectorise the sampler with ``jax.vmap`` instead of looping.
    keys = jax.random.split(jax.random.PRNGKey(seed), N)
    print(jax.vmap(f)(keys))


if __name__ == "__main__":
    main()