pgenevski opened this issue 1 year ago
Another example is cell 24, where you are effectively reusing the same key in the for loop by doing this on every iteration:
```python
key = random.PRNGKey(0)
```
It would be better to seed the PRNG once and split the key in every iteration, as in:
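```python
from jax import random

key = random.PRNGKey(42)  # seed once, outside the loop
for _ in range(3):
    # Split off a fresh subkey each iteration; keep `key` for future splits.
    key, subkey = random.split(key)
    a = random.normal(key=subkey)
    print(a)
```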
Output:
```
1.3694694
-0.19947024
-2.2982783
```
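To make the difference concrete, here is a minimal sketch contrasting the two patterns. The helper names `reseeded_samples` and `split_samples` are hypothetical and not taken from the notebook. Re-creating `PRNGKey(0)` inside the loop hands `random.normal` the identical key every time, so every "random" draw is the same number; splitting once per iteration gives each draw its own subkey.

```python
from jax import random


def reseeded_samples(n):
    """Anti-pattern: re-create the same key on every iteration."""
    samples = []
    for _ in range(n):
        key = random.PRNGKey(0)  # same seed each time -> same key -> same draw
        samples.append(float(random.normal(key)))
    return samples


def split_samples(n, seed=42):
    """Seed once, then split off a fresh subkey on every iteration."""
    key = random.PRNGKey(seed)
    samples = []
    for _ in range(n):
        key, subkey = random.split(key)  # `key` carries state forward
        samples.append(float(random.normal(subkey)))
    return samples


print(reseeded_samples(3))  # three identical values
print(split_samples(3))     # three different values
```

Because JAX's PRNG is stateless, nothing advances between calls unless you split; the key itself is the entire random state.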
Hi,

I noticed that you are reusing the same key in, e.g., cell 23 of main/src/notebooks/jax_tutorials/chapter_5_vmap_pmap.ipynb. It looks like you should be passing the subkey to `random.randint`, not the original key; as written, the subkey is never used.
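Since the contents of cell 23 are not reproduced here, the following is only a sketch of the suggested fix, with a hypothetical helper `draw_indices` and assumed `shape`, `minval` and `maxval` values. The point is that the freshly split `subkey` is what gets consumed by `random.randint`, while the returned `key` is kept only for further splitting.

```python
from jax import random


def draw_indices(key, n, maxval):
    """Hypothetical helper: split first, then sample from the fresh subkey.

    Returns the advanced key (for later splits) and the sampled integers.
    """
    key, subkey = random.split(key)
    # Use subkey here, not key: key is reserved for future splits.
    idx = random.randint(subkey, shape=(n,), minval=0, maxval=maxval)
    return key, idx


key = random.PRNGKey(0)
key, first = draw_indices(key, 5, 10)
key, second = draw_indices(key, 5, 10)
print(first, second)
```

Returning the advanced key from the helper keeps the split-then-consume discipline in one place, so callers cannot accidentally reuse a key that has already been spent.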