AakashKumarNain / TF_JAX_tutorials

All about the fundamental blocks of TF and JAX!
MIT License

Incorrect usage of PRNG #4

Open pgenevski opened 1 year ago

pgenevski commented 1 year ago

Hi,

I noticed that you are reusing the same key, e.g. in cell 23 of `main/src/notebooks/jax_tutorials/chapter_5_vmap_pmap.ipynb`:

```python
key, subkey = random.split(key)
rotate = random.randint(key, shape=[batch_size], minval=0, maxval=2)
```

It looks like you should be using `subkey` in `random.randint`, not the original `key`. As written, `subkey` is never used, and `key` is consumed by `randint` while also being kept for further splits.
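A minimal sketch of the corrected pattern, assuming `batch_size` is an integer and that `key` is carried forward to the next split. The helper name `sample_rotations` is hypothetical and not taken from the notebook. The second half shows why reuse matters: the same key always produces the same draw.

```python
import jax.numpy as jnp
from jax import random

def sample_rotations(key, batch_size):
    # Hypothetical helper: split first, keep `key` for future splits,
    # and consume only `subkey` for this draw.
    key, subkey = random.split(key)
    rotate = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    return key, rotate

key = random.PRNGKey(0)
key, rotate_a = sample_rotations(key, batch_size=8)
key, rotate_b = sample_rotations(key, batch_size=8)

# Reusing one key for two draws yields identical results, which is
# what you want to avoid when the draws should be independent.
k = random.PRNGKey(0)
same_1 = random.randint(k, shape=[8], minval=0, maxval=2)
same_2 = random.randint(k, shape=[8], minval=0, maxval=2)
print(bool(jnp.array_equal(same_1, same_2)))  # True
```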

pgenevski commented 1 year ago

Another example is cell 24, where you are effectively reusing the same key in the for loop by re-creating it several times:

```python
key = random.PRNGKey(0)
```
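To make the effect concrete, here is a small sketch (not the notebook's actual cell) that re-seeds inside a loop. Because `random.PRNGKey(0)` always produces the same key, every iteration draws the same value.

```python
from jax import random

# Anti-pattern: re-creating the key inside the loop restarts the stream,
# so every iteration draws the same "random" number.
values = []
for _ in range(3):
    key = random.PRNGKey(0)
    values.append(float(random.normal(key)))

print(values)  # three identical numbers
```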

It would be better to seed the PRNG once and split the key in every iteration, as in:

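```python
from jax import random

key = random.PRNGKey(42)

for _ in range(3):
    # Split off a fresh subkey each iteration and carry `key` forward.
    key, subkey = random.split(key)
    a = random.normal(key=subkey)
    print(a)
```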

Output:

```
1.3694694
-0.19947024
-2.2982783
```
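Since the notebook is about `vmap`, another option is to split once into several subkeys and draw from all of them in a single vectorized call. This is a sketch of an alternative, not the approach suggested above. The values will differ from the loop output because the keys are derived differently.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)

# Split once into 4 keys: keep the first to carry forward and use the
# remaining 3 as one-shot subkeys, so no key is consumed twice.
keys = random.split(key, 4)
key, subkeys = keys[0], keys[1:]

# Map random.normal over the leading axis of `subkeys`: one scalar per key.
samples = jax.vmap(random.normal)(subkeys)
print(samples)
```

Splitting into `n` keys up front makes the number of draws explicit and lets the sampling run as one batched operation instead of a Python loop.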