Closed Chulabhaya closed 4 months ago
Interesting! I can take a closer look next week. But first, can you check if the weights are significantly different? I want to rule out asynchronous wandb callbacks.
Hi @luchris429! I appreciate your prompt reply, and thanks again for all your awesome work with this repo. To answer your question, I looked at the weights of the first dense layer for another set of runs, and they are in fact different at the point I checked (~50k timesteps into CartPole-v1). The differences aren't drastic, but runs whose returns match have identical weights, so it's clearly the diverging weights that lead to the diverging results.
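For reference, the kind of check described above can be sketched like this (a minimal sketch with NumPy; `weights_match` and the `run1`/`run2` dicts are hypothetical names, assuming the checkpointed params have been flattened into dicts of arrays):

```python
import numpy as np

def weights_match(params_a, params_b):
    """True if two flat dicts of weight arrays are exactly equal
    (rtol=0, atol=0 means no tolerance: any drift counts as different)."""
    if params_a.keys() != params_b.keys():
        return False
    return all(
        np.allclose(params_a[k], params_b[k], rtol=0.0, atol=0.0)
        for k in params_a
    )

# Hypothetical usage: compare the first dense layer across two runs
run1 = {"dense_0/kernel": np.ones((4, 64)), "dense_0/bias": np.zeros(64)}
run2 = {"dense_0/kernel": np.ones((4, 64)), "dense_0/bias": np.zeros(64)}
print(weights_match(run1, run2))  # -> True
```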
What's interesting, however, is that if you look at both these plots and the plots I posted above, all runs are identical up to a certain number of timesteps, after which some diverge while others stay the same.
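One way to pin down where runs split, as described above, is to find the first timestep at which two logged metric curves disagree (a sketch; `first_divergence` is a hypothetical helper, assuming the curves are logged at the same timesteps):

```python
import numpy as np

def first_divergence(curve_a, curve_b):
    """Index of the first timestep where two metric curves differ,
    or None if they are identical end to end."""
    diff = ~np.isclose(curve_a, curve_b)
    if not diff.any():
        return None
    return int(np.argmax(diff))  # argmax of a bool array = first True

a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([1.0, 2.0, 3.5, 4.0])
print(first_divergence(a, b))  # -> 2
```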
Here's a plot of the runs; you can see that two have identical returns while the third differs:
Now here are snapshots of part of the weights:
Weight snapshot for two runs that are the same:
Weight snapshot for run that is different:
Mystery solved! It turns out the issue was the non-determinism that exists by default when training on a GPU. I verified this by training on CPU only for 10 runs, and the results are identical across all of them:
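A minimal sketch of forcing CPU-only execution for such a check (assumption: the environment variable must be set before `import jax` runs anywhere in the process, otherwise it has no effect):

```python
import os

# Restrict JAX to the CPU backend; non-deterministic GPU kernels
# (e.g. atomics in scatter/reduction ops) are then never used, so
# repeated runs with the same PRNG seed should match exactly.
os.environ["JAX_PLATFORMS"] = "cpu"

# import jax  # only import jax AFTER the variable above is in place

print(os.environ["JAX_PLATFORMS"])  # -> cpu
```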
I then did GPU training with the following line, which has to run before JAX initializes its backends: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
(ref: https://github.com/google/jax/discussions/10674)
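Spelled out as a runnable fragment (a sketch; appending rather than overwriting is an assumption, to preserve any XLA flags already set in the environment):

```python
import os

# Must run before jax/XLA initializes the GPU backend. Deterministic
# GPU ops can slow training, but make runs bitwise-reproducible.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (existing + " --xla_gpu_deterministic_ops=true").strip()

print(os.environ["XLA_FLAGS"])
```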
10 runs of this are also completely identical:
If I remove this flag, the GPU training becomes non-deterministic pretty quickly:
Hey all! I'm trying to track down an apparent reproducibility issue with the PPO implementation after adding some simple WandB logging. I ran the same code 7 times; 4 of the runs produced identical results, while the other 3 differed:
Would anyone have any ideas as to why this might be happening?
Here's my slightly modified code: