vwxyzjn / cleanba

CleanRL's implementation of DeepMind's Podracer Sebulba Architecture for Distributed DRL

add grad accum for ppo + envpool + impala atari wrapper #8

Closed · 51616 closed 1 year ago

51616 commented 1 year ago

Use the `MultiSteps` wrapper from optax to do gradient accumulation and avoid OOM errors.

~~Note: I'm not sure how to log the learning rate when using the wrapper. I looked through the `agent_state.opt_state` but didn't find the current learning rate, so I commented it out for now.~~ Now the code logs the learning rate correctly.

~~Here's the result using two random seeds. I'll run one more run tonight.~~ Here's the result with 3 random seeds.

[image: training results over 3 random seeds]

vwxyzjn commented 1 year ago

Awesome! Thanks for the PR @51616