Closed: edwhu closed this pull request 3 weeks ago
LGTM at first glance! I agree that we should follow the original repo's DMC setup as closely as possible while we're doing the performance comparison. Eventually I would like to opt for the cleaner DMC -> gym wrapper to keep things simple to understand and maintain.
I'll run and double check everything this afternoon, then merge if all is well on my end. Thanks a ton! Your contributions so far have been awesome!
Added some bug fixes for checkpoint loading
Based on our training reward plots in issue #1, I think everything in this PR is working as intended.
Only thing I would change is renaming the env.benchmark field to env.backend in the config. Otherwise, LGTM! I'll merge if you don't have any other changes you want to add in this PR.
Thanks for the feedback, will incorporate it.
I'm still a bit wary of merging from just 1 task, although Humanoid is probably one of the harder tasks. RL is notoriously finicky.
I think it's worth it to try all 4 tasks Nicklas mentioned before we merge. Do you have some time over the weekend to help out here? We should run ~3 seeds for each task, for 4M steps. If the curves look good, then let's merge.
I can finish the humanoid-stand jobs over the weekend, and possibly one more (finger-turn-hard?)
Yeah, happy to help. Just got a 4090 for my workstation at the lab, so this is a good chance to break it in! I'll take Cheetah and Walker.
We should probably run 1 env, UTD=1 for all of them if we're going for a comparison with the original. Although it's interesting to me that more envs and lower UTD doesn't seem to hurt sample efficiency much at all.
Great to hear! Yes, 1 env and 1 UTD is good.
We probably don't need to run every task for the full 4M steps: if a task is clearly solved by 1M or 2M, we can stop the job early. But for Humanoid, I should probably run for at least 2M steps.
Could you share the plotting code in a gist so I can report my task curves as well?
I'm about to kick off my runs for tonight, but I just pushed a few small changes. I switched to float32 instead of bfloat16 to make things 1:1 with the main repo at the cost of a small speed decrease. I also switched to the built-in jnp.percentile function for actor loss scaling.
I don't expect either of these to affect learning performance but figured you should be aware of the changes before starting your runs.
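For reference, jnp.percentile follows numpy's semantics, so the switch is behavior-preserving. A minimal sketch of percentile-based loss scaling; the 5th/95th percentiles and epsilon here are illustrative, not necessarily the repo's exact values:

```python
import numpy as np  # jnp.percentile mirrors np.percentile

def percentile_scale(qs, lo=5.0, hi=95.0, eps=1e-6):
    # Spread between two percentiles of the Q-values; dividing the actor
    # loss by this keeps its magnitude roughly invariant to reward scale.
    p_lo, p_hi = np.percentile(qs, [lo, hi])
    return max(p_hi - p_lo, eps)

qs = np.linspace(0.0, 4.0, 5)
print(percentile_scale(qs))  # ~3.6 for [0, 1, 2, 3, 4]
```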
In the walker and cheetah environments, the jax agent seems to lag the original in final reward by ~100. I haven't identified why that is, but just pushed a commit to change some computations to exactly match the original implementation.
There are also some differences between eval and train in the original TDMPC2 codebase, while this codebase only does training. I wonder if we're missing any of these details; would be good to check with Nicklas.
For example, the dropout layers are set to eval mode during evaluation. https://github.com/nicklashansen/tdmpc2/blob/5f6fadec0fec78304b4b53e8171d348b58cac486/tdmpc2/common/world_model.py#L50-L56
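For reference, eval-mode dropout just becomes the identity. A minimal standalone sketch of the behavior (not the repo's code; the function and its signature are hypothetical):

```python
import numpy as np

def dropout(x, rate, training, rng):
    # In eval mode (training=False) dropout is the identity, matching the
    # original repo's model.eval() behavior during evaluation rollouts.
    # In train mode, kept units are rescaled by 1/(1 - rate).
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
print(dropout(np.ones(4), 0.5, training=False, rng=rng))  # [1. 1. 1. 1.]
```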
I don't think it's a train/eval difference, as the original repo got similarly good performance in just its training phase (I tested this today). Interesting problem...
Just found a problem that potentially explains the performance gap: the DMControl environments are infinite horizon, but the TimeStepToGymWrapper sets both the terminated and truncated flags. This doesn't matter in the main branch of the original repo, since they don't account for terminal states in the TD target or planning, but we do account for them.
Not saying that's the issue, but I think it will make a difference.
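A minimal sketch of the distinction (function and names hypothetical): since DMC tasks are infinite-horizon, a wrapper should report episode ends through truncated only, never terminated:

```python
def split_timestep(timestep_is_last, step_count, max_steps=1000):
    # DMC tasks are infinite-horizon: episodes only end via the time limit,
    # so `terminated` must stay False and the TD target keeps bootstrapping
    # through the cutoff.
    terminated = False  # never a true terminal state in DMC
    truncated = timestep_is_last or step_count >= max_steps
    return terminated, truncated

print(split_timestep(timestep_is_last=True, step_count=1000))  # (False, True)
```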
That was indeed the problem! My initial experimentation with the latest commits improve our performance to match the original repo. Just kicked off a new set of Walker/Cheetah runs for more thorough analysis.
Great find! I'm a bit surprised, since in practice the distinction between the truncation and termination flags usually doesn't impact performance much. It matters more in periodic tasks, which, in retrospect, DMC tasks are.
I finished my humanoid runs - they take a decent amount of time, around 15 hours each, so 60 hours total for 4 seeds. I'll plot them.
Then, I can pull the updated code and rerun the humanoid task. If you have spare time, could you run the remaining task (finger turn hard)?
> Great find! I'm a bit surprised, since in practice the distinction between the truncation and termination flags usually doesn't impact performance much. It matters more in periodic tasks, which, in retrospect, DMC tasks are.
The problem is that we use the termination flag when computing the TD target, so those values end up weird if termination=truncation.
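Concretely, a one-step TD target masks bootstrapping with the terminated flag, so mislabeling a time-limit cutoff as termination biases the target low. A minimal sketch (values illustrative):

```python
def td_target(reward, q_next, terminated, gamma=0.99):
    # Bootstrapping is masked by `terminated`, not `truncated`. If a DMC
    # time-limit cutoff is mislabeled as termination, the next-state value
    # is dropped and the target is biased low.
    return reward + gamma * (1.0 - terminated) * q_next

# Same transition, labeled as truncation vs. termination:
print(td_target(1.0, 10.0, terminated=0.0))  # ~10.9 (correct for DMC)
print(td_target(1.0, 10.0, terminated=1.0))  # 1.0 (value of next state lost)
```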
Yeah I'll run it after my current runs finish. I'm doing 2M steps over 3 seeds, which takes about 5 hours per run.
Here are the humanoid-stand results with the older code. It also seems to lag behind by ~100. Will run the updated version.
And here's the updated plotting code that takes into account the seeds. https://gist.github.com/edwhu/710d5202435831c5c761ea53d63a952e
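For anyone reproducing the plots, the seed aggregation presumably boils down to something like this sketch (function name hypothetical): average across seeds and plot the mean with a standard-deviation band.

```python
import numpy as np

def aggregate_seeds(curves):
    # curves: shape (n_seeds, n_points). Returns the per-step mean and
    # std across seeds, suitable for plt.plot(mean) + plt.fill_between.
    curves = np.asarray(curves, dtype=float)
    return curves.mean(axis=0), curves.std(axis=0)

mean, std = aggregate_seeds([[0, 1, 2], [0, 3, 4]])
print(mean)  # [0. 2. 3.]
```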
So my laptop lost power and died during the walker run, but here are some good results!
Finger turn:
Cheetah run:
Any updates on the new humanoid runs? I'm having trouble with walker-run: the reward curve matches the original implementation, but I'm getting an invalid physics state error due to either a NaN or inf action. If there were a hidden bug in the JAX implementation, I'm not sure why this error would only happen in this environment.
2/4 seeds done, 17 hours per seed (2 hours slower than before, maybe due to the float dtype). The performance of the 2 seeds is higher than with the old code, so that's good.
For the physics error, you could check with Nicklas if he ran into a similar issue with the original codebase - it's possible that if the policy gets really good at running, it can go too fast and violate some physics constraint.
He includes a nan_to_num() call in the planner function that I initially didn't think was necessary. I'll add that back and try again. 2M steps of walker-run only takes 4 hours for me, so I'll post an update soonish.
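For context, a guard like that presumably amounts to the following sketch (the clipping bounds here are illustrative, not necessarily the repo's; jnp.nan_to_num has the same semantics as numpy's):

```python
import numpy as np  # jnp.nan_to_num / jnp.clip mirror the numpy versions

def sanitize_action(action):
    # Guard the planner output before stepping the simulator: replace
    # NaN/inf with finite values, then clip into the action bounds.
    return np.clip(np.nan_to_num(action), -1.0, 1.0)

print(sanitize_action(np.array([np.nan, 2.0, -0.5])))  # nan -> 0, 2.0 -> 1.0
```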
Yep, the issue persists after that call. I am also inclined to blame the simulator unless Nicklas has additional insight.
Let me check if the NaNs are coming from the gradient computation or the simulator.
Both hypotheses are correct: the NaNs occur in the gradients when the simulator observations are large. Fixed by zeroing out NaN gradients.
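In optax this is what the zero_nans() transform does; a minimal standalone sketch of the same idea on a flat gradient dict:

```python
import numpy as np

def zero_nans(grads):
    # Same spirit as optax.zero_nans(): replace any NaN gradient entries
    # with 0 so a single bad batch can't poison the parameters.
    return {k: np.where(np.isnan(g), 0.0, g) for k, g in grads.items()}

grads = {"w": np.array([0.1, np.nan]), "b": np.array([np.nan])}
clean = zero_nans(grads)
print(clean["w"])  # NaN entry replaced by 0
```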
I think we can safely say this PR is working properly after the latest push. After my last few runs finish up, I would like to add these results to the readme before merging.
One seed of walker-run (with the other 2 following a similar trajectory so far):
Let's wait for the humanoid runs to finish before merging. We can never be too careful with RL.
In the meantime, we can record the current performance profiles, like FPS and returns of each env, so we can track any performance regressions in future updates. It might be good to check now to see how much overhead the NaN checks and the float dtype introduce.
Even better is if we have some automated test, but that can be another feature in the future.
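A minimal FPS probe along these lines could serve as the profiling baseline (a sketch; step_fn is a stand-in for one combined environment/agent step):

```python
import time

def measure_fps(step_fn, n_steps=1000):
    # Environment-steps-per-second probe for regression tracking:
    # time n_steps calls of step_fn and report the rate.
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    return n_steps / (time.perf_counter() - start)

fps = measure_fps(lambda: None, n_steps=10_000)
```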
Here are the Humanoid Stand results with the new code. It's great that it seems to outperform the original implementation.
Looks excellent! Can you re-make the humanoid plot with the formatting below for the readme?
https://gist.github.com/ShaneFlandermeyer/d37ba661d06e7e04cb25a195623ac007
Yup, will do it after work.
> Both hypotheses are correct: the NaNs occur in the gradients when the simulator observations are large. Fixed by zeroing out NaN gradients.
Hmm, I thought TDMPC2 somewhat accounts for this already through SimNorm.
> Hmm, I thought TDMPC2 somewhat accounts for this already through SimNorm.
I looked into it deeper yesterday. In my experiments, only the encoder gradients ever went NaN (i.e., before the SimNorm). I think we could get away with using zero_nans() in just the encoder optimizer. I'll do a quick run to verify.
Nope; it appears that all the zero_nans() calls are necessary, unfortunately.
Here's the humanoid stand plot. Could we add the title of the task to the chart? So it looks like this:

```python
plt.title('Humanoid Stand', fontsize=12, fontweight='bold')
```
Updated readme with new figures! Going to complete the merge now!
Support DMC environments. The DMC env loading code is largely lifted from the TDMPC2 codebase.
Small improvements: