kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Non-Deterministic Output #113

Closed: jerrygreen closed this issue 3 years ago

jerrygreen commented 3 years ago

Is this solvable? How can I make this model's output deterministic?

Ideally, I'd like to be able to pass a seed, as with noise-based RNG functions, so that I can control the output by choosing different seeds.

kingoflolz commented 3 years ago

You just need to modify the function to take a seed as an argument instead of generating one internally: https://github.com/kingoflolz/mesh-transformer-jax/blob/1dfdb06eed54a7d23298e43c06b266a3a45d4ab6/mesh_transformer/transformer_shard.py#L327
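To illustrate the idea, here is a minimal sketch of seeded sampling in plain JAX. It does not reproduce the repository's generate code. `sample_tokens` is a hypothetical stand-in that reuses one fixed logits vector at every step instead of running a model. The point it shows is that if the sampling function builds its PRNG key from a caller-supplied seed, rather than from a freshly generated random number, the same seed always yields the same tokens, and a different seed gives a different (but equally reproducible) sequence.

```python
import jax
import jax.numpy as jnp


def sample_tokens(logits, seed, gen_length, temperature=1.0):
    """Hypothetical stand-in for a generate() call that accepts an explicit seed.

    logits: (vocab,) array of unnormalized scores. For simplicity the same
    logits are reused at every step; a real model would recompute them from
    the growing context.
    """
    key = jax.random.PRNGKey(seed)  # same seed -> same starting key
    tokens = []
    for _ in range(gen_length):
        # Split off a fresh subkey per step so steps are not correlated,
        # while the whole sequence stays a pure function of `seed`.
        key, subkey = jax.random.split(key)
        tok = jax.random.categorical(subkey, logits / temperature)
        tokens.append(int(tok))
    return tokens


logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
a = sample_tokens(logits, seed=42, gen_length=8)
b = sample_tokens(logits, seed=42, gen_length=8)
c = sample_tokens(logits, seed=7, gen_length=8)  # different seed, different draw
print(a == b)  # True
```

Applied to the linked line, the same pattern would mean adding a seed parameter to the generation function and building its key from that parameter instead of generating the seed internally; the exact parameter name and how it is threaded through the sharded code are left as assumptions here.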