Use the `MultiSteps` wrapper from `optax` to do gradient accumulation and avoid OOM errors.
~Note: I'm not sure how to log the learning rate when using the wrapper. I looked through `agent_state.opt_state` but didn't find the current learning rate, so I commented it out for now.~
Now the code logs the learning rate correctly.
~Here are the results using two random seeds. I'll run one more tonight.~
Here are the results with 3 random seeds.