danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Pong results do not match paper #138

Open George614 opened 1 month ago

George614 commented 1 month ago

Hi Danijar,

Thanks for sharing this amazing repo and creating a robust model-based RL algorithm! I've been playing with the replay buffer and trying to reproduce some of the results. I ran the code on Pong with

`python dreamerv3/main.py --logdir ./logdir/uniform_pong --configs atari --task atari_pong --run.train_ratio 32`

using the default configuration on Ubuntu 22.04 LTS with an RTX 3090 GPU. Somehow, the agent does not learn the Pong task over 400K env steps (the budget from the first version of the paper), and I'm not sure what went wrong. I've tried the default uniform replay (cyan curve in the figure), a mixed replay (gray curve) with a ratio of (0.5, 0.3, 0.2), and uniform replay with `compute_dtype: float16` (magenta curve), since I'd seen some warnings from CUDA and XLA.

[Screenshot: training curves for the three runs]

Here are the package versions that I installed:

python 3.11.9
jax 0.4.28
jax-cuda12-pjrt 0.4.28
jax-cuda12-plugin 0.4.28
jaxlib 0.4.28+cuda12.cudnn89
ale-py 0.8.1
gymnasium 0.29.1
tensorflow-cpu 2.16.1
tensorflow-probability 0.24.0

Please let me know if anything was not set up properly. Thank you!

IcarusWizard commented 1 month ago

As far as I know, `run.train_ratio` should be 1024 for Atari100k.
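To see why the ratio matters so much, here is a back-of-the-envelope sketch. It assumes (this interpretation and the default batch sizes are my assumptions, not taken from the repo) that `run.train_ratio` is roughly the number of replayed time steps trained on per environment step, with a batch of 16 sequences of length 64:

```python
def gradient_steps(env_steps, train_ratio, batch_size=16, batch_length=64):
    """Approximate number of gradient updates over a run.

    Assumption: train_ratio is replayed time steps per env step, so
    gradient_steps ~= env_steps * train_ratio / (batch_size * batch_length).
    batch_size=16 and batch_length=64 are assumed defaults.
    """
    return env_steps * train_ratio // (batch_size * batch_length)

# Atari100k budget: 400K frames / action repeat 4 = 100K env steps.
print(gradient_steps(100_000, 32))    # the issue's setting -> 3125
print(gradient_steps(100_000, 1024))  # suggested for Atari100k -> 100000
```

Under these assumptions, `train_ratio 32` gives only a few thousand gradient updates over the whole Atari100k budget, which would plausibly explain a flat learning curve.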

NonsansWD commented 1 month ago

Hey, first of all, I think the first comment on this is right: you should increase the train_ratio. That was confusing for me too at first, but it should solve the issue. A quick off-topic question, though: I see you are running pretty recent versions of tensorflow-cpu as well as jax. Did you run into any issues where the pip installation stated that jax requires `ml_dtypes` >= 4.0 while tensorflow requires that library to be version 3.2?

George614 commented 1 month ago

> As far as I know, the run.train_ratio should be 1024 for Atari100k.

Thanks @IcarusWizard and @NonsansWD I'll try your suggestion!

George614 commented 1 month ago

> Quick off topic question tho: I see you are running pretty recent versions of tensorflow-cpu as well as jax. Did u run into any issues where the pip installation stated that jax requires mldtype >= 4.0 and tensorflow requires that library to be version 3.2?

I have not run into that particular issue. I'd suggest installing tensorflow-cpu first (maybe a less recent version) and then installing JAX.

NonsansWD commented 1 month ago

> I have not run into that particular issue. I'd suggest that you install tensorflow-cpu first (maybe a less recent version) then install JAX.

Alright, good to know. In the end I was able to fix my issue and everything works fine. The only problem I'm left with: I just realized that the resulting folder called "replay" does not contain raw frames but other data such as rewards. Do you by any chance know a way of obtaining a video of the agent's steps, so I can watch it do its stuff without too much effort? I feel like I'm missing something, because I also don't know where to get these wonderful score plots. Or do I have to construct the plot myself with matplotlib? Sorry for going off topic.
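On the score plots: the trainer writes a `metrics.jsonl` file into the logdir (at least in the versions I have seen). Here is a minimal sketch for pulling episode scores out of it; the exact key names (`step`, `episode/score`) are assumptions, so check a few lines of your own metrics file first:

```python
import json
from pathlib import Path

def load_scores(logdir):
    """Collect (steps, scores) from DreamerV3's metrics.jsonl.

    Assumes each line is a JSON object and that episode returns appear
    under the key 'episode/score' alongside a 'step' key. These names
    may differ between versions of the repo, so verify them first.
    """
    steps, scores = [], []
    for line in Path(logdir, 'metrics.jsonl').read_text().splitlines():
        record = json.loads(line)
        if 'episode/score' in record:
            steps.append(record['step'])
            scores.append(record['episode/score'])
    return steps, scores
```

With matplotlib installed, `plt.plot(*load_scores('./logdir/uniform_pong'))` then gives a score curve like the ones in the screenshot above. For videos, I believe the logdir also contains TensorBoard event files with image/video summaries, so `tensorboard --logdir ./logdir` may already show the agent playing, but I have not verified which summaries your version writes.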