For sequential data generation, the official way is to split the key on every iteration and sample from the fresh subkey.
# Canonical sequential pattern: each pass replaces `key` with a fresh carry
# key and draws from a subkey that is never used again.
for _ in range(N):
    key, subkey = jax.random.split(key)
    jax.random.normal(subkey)
Or we can build a generator.
# with a limit to avoid infinite `for` loop
def key_gen(seed, N=10):
    """Yield up to ``N`` PRNG keys derived from ``seed``.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.
        N: Maximum number of keys to produce (an integer); nothing is
            yielded when ``N <= 0``.

    Yields:
        A fresh subkey on every step. The internal carry key is only ever
        split and never handed to the caller, so no key is both consumed
        by a sampler and split again.
    """
    key = jax.random.PRNGKey(seed)
    for _ in range(N):
        # Split *before* yielding: the caller gets the subkey, while the
        # carry key stays private and is used solely for the next split.
        key, subkey = jax.random.split(key)
        yield subkey
# Draw three samples, one per key produced by the bounded generator.
key_iter = key_gen(seed=10, N=3)
for key in key_iter:
    print(jax.random.normal(key))
# or define without limit and use `next()`
def key_gen(seed):
    """Yield an endless stream of PRNG keys derived from ``seed``.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.

    Yields:
        A fresh subkey on every step. The generator never terminates, so
        pull keys with ``next()`` (or bound the iteration yourself). The
        internal carry key is only ever split and never handed out, so no
        key is both consumed by a sampler and split again.
    """
    key = jax.random.PRNGKey(seed)
    while True:
        # Split *before* yielding: the caller gets the subkey, while the
        # carry key stays private and is used solely for the next split.
        key, subkey = jax.random.split(key)
        yield subkey
# Pull keys on demand from the unbounded generator; two samples are printed.
keys = key_gen(10)
for _ in range(2):
    print(jax.random.normal(next(keys)))
For non-sequential data, it is best to vectorise using jax.vmap.
def f(key):
    """Draw one standard-normal sample using the PRNG key ``key``."""
    sample = jax.random.normal(key)
    return sample
# Create N keys in one call, then map `f` over their leading axis.
keys = jax.random.split(key, num=N)
jax.vmap(f, in_axes=0)(keys)