google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Brax Halfcheetah Exploding Gradients #186

Open siwei0729 opened 2 years ago

siwei0729 commented 2 years ago

Hi,

I'm working on a project that uses differentiable dynamics. However, on the halfcheetah task I'm running into exploding gradients. I have created a repo to reproduce this problem.

The problem can be reproduced with the official analytical policy gradient implementation, apg.py, using the official reward function. The only change I made is to print the gradient norm before clipping.
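For readers who want to see what "print the gradient norm before clipping" can look like, here is a minimal sketch in plain JAX. It is not the code from apg.py: the helper names `global_norm` and `clip_by_global_norm` and the example gradient pytree are made up for illustration.

```python
import jax
import jax.numpy as jnp


def global_norm(grads):
    """L2 norm taken over every leaf of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their global norm is at most max_norm.

    Also returns the raw (pre-clipping) norm so it can be logged.
    Note: if the raw norm is inf, the scale becomes 0 and any inf
    entries turn into nan (inf * 0), so clipping cannot recover a
    useful update in that case.
    """
    raw_norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (raw_norm + 1e-6))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, raw_norm


# Hypothetical gradient pytree, standing in for policy gradients.
example_grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
clipped_grads, raw = clip_by_global_norm(example_grads, max_norm=1.0)
print("grad_raw", raw)
```

Logging the raw norm matters here because the clipped norm alone would hide how large the gradient was before rescaling.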

To reproduce

python apg.py

Environment

python                    3.8
brax                      0.0.12
jax                       0.3.5
jaxlib                    0.3.5+cuda11.cudnn82

nvidia-smi

NVIDIA-SMI 510.54       Driver Version: 510.54       CUDA Version: 11.6

Gradient norm from Halfcheetah

grad_raw [inf]
grad_raw [inf]
grad_raw [inf]
grad_raw [3.7764926e+18]
grad_raw [inf]
grad_raw [inf]

Gradient norm from ant

grad_raw [1.7340995]
grad_raw [2.4045153]
grad_raw [2.8107145]
grad_raw [1.8724597]
grad_raw [3.0794723]
grad_raw [2.4992204]

cdfreeman-google commented 2 years ago

Could you plot the gradient norm as a function of trajectory length for, say, a random policy? We've seen this kind of thing before, and it usually reduces to chaotic/unstable dynamics in the system, so you may need to introduce a truncation length to get stable behavior.
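To make the suggestion concrete, here is a rough sketch of how gradient norm can depend on trajectory length, and how a truncation length caps it. It deliberately uses a toy unstable linear system rather than Brax: the dynamics `x_{t+1} = A * x_t + theta`, the gain `A`, and the helpers `rollout` and `grad_norm` are all assumptions chosen so the result can be checked by hand. Truncation is done with `jax.lax.stop_gradient`.

```python
import jax
import jax.numpy as jnp

A = 2.0  # hypothetical unstable gain (|A| > 1); not a Brax quantity


def rollout(theta, horizon, truncation=None):
    """Final state of x_{t+1} = A * x_t + theta, starting from x_0 = 0."""
    x = jnp.zeros(())
    for t in range(horizon):
        if truncation is not None and t % truncation == 0:
            # Cut the backward pass here: steps before t contribute
            # nothing to the gradient.
            x = jax.lax.stop_gradient(x)
        x = A * x + theta
    return x


def grad_norm(horizon, truncation=None):
    """|d x_T / d theta| for a given trajectory length."""
    g = jax.grad(lambda th: rollout(th, horizon, truncation))(0.1)
    return jnp.abs(g)


horizons = [1, 5, 10, 20]
full = [float(grad_norm(h)) for h in horizons]
truncated = [float(grad_norm(h, truncation=5)) for h in horizons]

for h, f, tr in zip(horizons, full, truncated):
    print(f"horizon={h:3d}  full={f:.1f}  truncated(5)={tr:.1f}")
```

In this toy the untruncated gradient is `2**T - 1`, so it grows exponentially with the horizon and would overflow float32 for long enough rollouts, which resembles the `inf` norms reported above. Whether the same mechanism is at work in Halfcheetah is what the requested plot would help check. With a truncation length of 5, the gradient only flows through the last window of at most 5 steps, so it stays bounded regardless of horizon.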