danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Extremely large gradient and vanishing images when using jax.precision=float32 #91

Open ManfredStoiber opened 1 year ago

ManfredStoiber commented 1 year ago

First of all, thank you very much for this impressive work!

When using the configuration jax.precision=float32 for training with images, I always get an extremely large gradient (model_opt_grad_norm at ~6e+8). I assume that because of this, the openl image predictions become completely white. When training the dmc_walker_walk task with the dmc_vision configuration via the train.py script, the image_loss_mean is at about 7e+7. With other environments, the image_loss_mean starts at about 2000-5000, but the model_opt_grad_norm stays at ~6e+8.
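For reference on what the clipping setting is meant to do with a norm that large: the hypothetical sketch below (plain numpy, not code from the DreamerV3 repo) shows global-norm gradient clipping, where all gradient arrays are rescaled together so their combined L2 norm stays under a threshold. A norm of ~6e+8 would be scaled down drastically, which is why clipping alone masks rather than fixes the underlying instability.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a dict of gradient arrays so their combined L2 norm
    does not exceed max_norm. Returns the clipped grads and the
    original global norm (the quantity logged as a *_grad_norm metric)."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads.values()))
    scale = min(1.0, max_norm / max(global_norm, 1e-12))
    return {k: g * scale for k, g in grads.items()}, global_norm

# Hypothetical exploding gradients on the order of the reported ~6e+8 norm:
grads = {"w": np.full((2, 2), 3e8), "b": np.array([4e8])}
clipped, norm = clip_by_global_norm(grads, max_norm=100.0)
```

After clipping, the update direction is preserved but its magnitude is capped at max_norm, so training can proceed even while the raw norm stays huge.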

I'm using float32 because I sporadically get NaNs during training with images when using float16.
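When chasing sporadic NaNs like this, it can help to check which gradient array goes non-finite first. A minimal numpy sketch (not DreamerV3 code; the module names are hypothetical) of such a check:

```python
import numpy as np

def first_nonfinite(grads):
    """Return the name of the first gradient array containing NaN or Inf,
    or None if all arrays are finite. Useful for pinpointing which part
    of the model overflows under reduced precision."""
    for name, g in grads.items():
        if not np.all(np.isfinite(g)):
            return name
    return None

# Hypothetical gradients where the decoder has overflowed:
grads = {"encoder": np.ones(3), "decoder": np.array([1.0, np.nan])}
bad = first_nonfinite(grads)  # 'decoder'
```

In a JAX training loop the same idea applies per pytree leaf, and jax.config.update("jax_debug_nans", True) can localize the producing operation.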

I already tried changing the lr and clipping values, as well as the image loss scale, but without success.

Am I maybe missing any other configurations I have to change when using float32?

Thank you for your help! Best regards

[image attachment]

return-sleep commented 10 months ago

May I ask whether you have successfully trained the DreamerV3 agent? I'm curious what the final loss of each component looks like, such as image, reward, or cont. When the reconstructed images are very similar, I find that the reward prediction is not as good as it should be. I'm not sure whether this also affects subsequent policy training. Thanks for sharing your thoughts.

ManfredStoiber commented 9 months ago

Unfortunately not, at least not when training on images in the walker environment.

danijar commented 5 months ago

Walker always worked for me from images, regardless of precision. I've just updated the paper and code, which has a better optimizer now. Curious if this is still an issue on your end.