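For sequential random-number generation, the officially recommended pattern is to split the key on every step and pass only the new subkey to the sampling function: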

# NOTE(review): `N` (number of draws) and the initial `key` are assumed to be
# defined before this snippet runs.
for _ in range(N):
    # Split first: the fresh `subkey` is consumed by the draw, while `key` is
    # kept only as the parent for the next split.
    key, subkey = jax.random.split(key)
    jax.random.normal(subkey)

Alternatively, the key splitting can be wrapped in a Python generator that yields a new key on each iteration.

# with a limit to avoid an infinite `for` loop
def key_gen(seed, N=10):
    """Yield ``N`` fresh PRNG keys derived from ``seed``.

    On every step the running key is split and only the new subkey is
    yielded. The running key itself is never handed to the caller, so a key
    the caller consumes (e.g. in ``jax.random.normal``) is never split again.
    This follows JAX's "never reuse a key" rule.

    Args:
        seed: integer seed passed to ``jax.random.PRNGKey``.
        N: number of keys to yield.

    Yields:
        A new subkey on each iteration.
    """
    key = jax.random.PRNGKey(seed)
    for _ in range(N):
        # Split *before* yielding: the caller receives `subkey`, and `key`
        # is kept solely for deriving the next pair.
        key, subkey = jax.random.split(key)
        yield subkey
 
key_iter = key_gen(10, N=3)
# The bounded generator stops by itself after N=3 keys.
for key in key_iter:
    sample = jax.random.normal(key)
    print(sample)
 
# or define without a limit and draw keys with `next()`
def key_gen(seed):
    """Yield an endless stream of fresh PRNG keys derived from ``seed``.

    On every step the running key is split and only the new subkey is
    yielded. The running key itself is never handed to the caller, so a key
    the caller consumes is never split again (JAX's "never reuse a key" rule).
    Take keys with ``next()``; a plain ``for`` loop over it never terminates.

    Args:
        seed: integer seed passed to ``jax.random.PRNGKey``.

    Yields:
        A new subkey on each ``next()`` call.
    """
    key = jax.random.PRNGKey(seed)
    while True:
        # Split *before* yielding so the yielded subkey is never reused.
        key, subkey = jax.random.split(key)
        yield subkey
 
keys = key_gen(10)

# Each next() advances the generator and hands back one key per draw.
for _ in range(2):
    print(jax.random.normal(next(keys)))

When the draws are independent of each other (non-sequential), it is better to split all the keys at once and vectorise the sampling with `jax.vmap`.

def f(key):
    """Draw one standard-normal sample (a scalar) from ``key``.

    Written as a per-key function so that ``jax.vmap`` can map it over a
    batch of keys produced by ``jax.random.split``.
    """
    sample = jax.random.normal(key)
    return sample
 
# NOTE(review): assumes `key` is a fresh key not already consumed elsewhere,
# and `N` (number of samples) is defined before this snippet.
# Split once into N independent keys, then sample from all of them in a
# single vectorised call (mapping over the leading axis of `keys`).
keys = jax.random.split(key, num=N)
jax.vmap(f, in_axes=0)(keys)