kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars 890 forks

Generating random numbers – None PRNGKey error #173

Closed: cifkao closed this issue 2 years ago

cifkao commented 2 years ago

I tried modifying the model in a way that requires generating random numbers inside the Transformer layer. Specifically, I added a call to hk.next_rng_key() to TransformerLayerShard.__call__ so that I can have a different random number for each batch. This results in the following error during training:

  ...
  File "/home/ocifka/mesh-transformer-jax/mesh_transformer/layers.py", line 312, in __call__
    key = hk.next_rng_key()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 638, in next_rng_key
    return next_rng_key_internal()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 643, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 599, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

As far as I can tell, this happens because a None PRNGKey is passed by default on this line: https://github.com/kingoflolz/mesh-transformer-jax/blob/e84de89465b8526784ab77804864fc3d22fde28c/mesh_transformer/transformer_shard.py#L141

I would appreciate any advice on where I should pass my PRNGKey in order to get a different random number for each training batch.

kingoflolz commented 2 years ago

You can remove the call to without_apply_rng, but then you need to feed in an RNG key every time you call train_loss_fn.