ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

DMC tasks #5

Closed by edwhu 3 weeks ago

edwhu commented 1 month ago
ShaneFlandermeyer commented 1 month ago

LGTM at first glance! I agree that we should follow the original repo's DMC setup as closely as possible while we're doing the performance comparison. Eventually I would like to opt for the cleaner DMC -> gym wrapper to keep things simple to understand and maintain.

I'll run and double check everything this afternoon, then merge if all is well on my end. Thanks a ton! Your contributions so far have been awesome!

edwhu commented 1 month ago

Added some bug fixes for checkpoint loading

ShaneFlandermeyer commented 1 month ago

Based on our training reward plots in issue #1, I think everything in this PR is working as intended.

The only thing I would change is renaming the env.benchmark field to env.backend in the config. Otherwise, LGTM! I'll merge if you don't have any other changes you want to add in this PR.

edwhu commented 1 month ago

Thanks for the feedback, will incorporate it.

I'm still a bit wary of merging based on just one task, although Humanoid is probably one of the harder ones. RL is notoriously finicky.

I think it's worth it to try all 4 tasks Nicklas mentioned before we merge. Do you have some time over the weekend to help out here? We should run ~3 seeds for each task, for 4M steps. If the curves look good, then let's merge.

I can finish the humanoid-stand jobs over the weekend, and possibly one more (finger-turn-hard?)

ShaneFlandermeyer commented 1 month ago

Yeah, happy to help. I just got a 4090 for my workstation at the lab, so this is a good chance to break it in! I'll take Cheetah and Walker.

We should probably run 1 env with UTD=1 for all of them if we're going for a comparison with the original. Although it's interesting to me that more envs and a lower UTD don't seem to hurt sample efficiency much at all.

edwhu commented 1 month ago

Great to hear! Yes, 1 env and 1 UTD is good.

We probably don't need to run every task for the full 4M steps; if it's obvious a task is solved by 1M or 2M, we can stop the job early. But for Humanoid, I should probably run for at least 2M steps.

Could you share the plotting code in a gist so I can report my task curves as well?

ShaneFlandermeyer commented 1 month ago

Here you go!

https://gist.github.com/ShaneFlandermeyer/ab2352a65279c7be0a2cfeb1371127fb

ShaneFlandermeyer commented 1 month ago

I'm about to kick off my runs for tonight, but I just pushed a few small changes. I switched to float32 instead of bfloat16 to make things 1:1 with the main repo at the cost of a small speed decrease. I also switched to the built-in jnp.percentile function for actor loss scaling.

I don't expect either of these to affect learning performance, but I figured you should be aware of the changes before starting your runs.
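For reference, the actor-loss scaling just boils down to an inter-percentile range of the Q-values; roughly something like the sketch below (the helper name is mine, and the 5/95 percentiles follow the original implementation):

```python
import jax.numpy as jnp

def percentile_scale(qs: jnp.ndarray, low: float = 5.0, high: float = 95.0) -> jnp.ndarray:
    # Inter-percentile range of the Q-values, used to keep the actor loss
    # on a roughly constant scale across tasks.
    return jnp.percentile(qs, high) - jnp.percentile(qs, low)

# Hypothetical usage (the running average and clamping details may differ):
# actor_loss = raw_actor_loss / jnp.maximum(percentile_scale(q_values), 1.0)
```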

ShaneFlandermeyer commented 1 month ago

In the walker and cheetah environments, the Jax agent seems to lag the original in final reward by ~100. I haven't identified why that is, but I just pushed a commit that changes some computations to exactly match the original implementation.

edwhu commented 1 month ago

There are also some differences between eval and train modes in the original TDMPC2 codebase, while this codebase only does training. I wonder if we're missing any of these details; it would be good to check with Nicklas.

For example, the dropout layers are set to eval mode during evaluation. https://github.com/nicklashansen/tdmpc2/blob/5f6fadec0fec78304b4b53e8171d348b58cac486/tdmpc2/common/world_model.py#L50-L56
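On the Flax side, the equivalent would be passing deterministic=True to the dropout layers at eval time; a rough sketch (the module and field names here are illustrative, not this repo's world model):

```python
import flax.linen as nn

class EncoderWithDropout(nn.Module):
    # Minimal sketch: dropout is only active when train=True, mirroring
    # the model.eval() behaviour in the PyTorch codebase.
    features: int = 256
    dropout_rate: float = 0.01

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.relu(nn.Dense(self.features)(x))
        # When train=True, apply() needs rngs={'dropout': key}.
        x = nn.Dropout(self.dropout_rate, deterministic=not train)(x)
        return nn.Dense(self.features)(x)
```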

ShaneFlandermeyer commented 1 month ago

I don't think it's a train/eval difference, as the original repo got similarly good performance in just its training phase (I tested this today). Interesting problem...

ShaneFlandermeyer commented 1 month ago

Just found a problem that potentially explains the performance gap: the DMControl environments are infinite horizon, but the TimeStepToGymWrapper uses both the terminated and truncated flags. This doesn't matter in the main branch of the original repo since they don't account for terminal states in the TD target or planning, but we account for those things.

Not saying that's the issue, but I think it will make a difference.
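For concreteness, the fix amounts to something like the wrapper below (a hypothetical gymnasium sketch, not the exact TimeStepToGymWrapper code):

```python
import gymnasium as gym

class InfiniteHorizonWrapper(gym.Wrapper):
    # Treat every DMC episode end as a truncation, never a termination,
    # so the TD target keeps bootstrapping through the time limit.
    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, reward, False, terminated or truncated, info
```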

ShaneFlandermeyer commented 1 month ago

That was indeed the problem! My initial experiments with the latest commits bring our performance up to match the original repo. Just kicked off a new set of Walker/Cheetah runs for more thorough analysis.

edwhu commented 1 month ago

Great find! I'm a bit surprised, since in practice the distinction between the truncation and termination flags doesn't impact performance too much. Usually it matters more in periodic tasks, which, in retrospect, DMC tasks are.

I finished my humanoid runs - they take a decent amount of time, around 15 hours each, so 60 hours total for 4 seeds. I'll plot them.

Then, I can pull the updated code and rerun the humanoid task. If you have spare time, could you run the remaining task (finger turn hard)?

ShaneFlandermeyer commented 1 month ago

> Great find! I'm a bit surprised, since in practice the distinction between the truncation and termination flags doesn't impact performance too much. Usually it matters more in periodic tasks, which, in retrospect, DMC tasks are.

The problem is that we use the termination flag when computing the TD target, so the targets end up skewed if truncations get reported as terminations.
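Concretely, the masking looks something like this (a sketch with illustrative names, not our exact code):

```python
import jax.numpy as jnp

def td_target(reward, next_q, terminated, discount=0.99):
    # If a truncation is mislabeled as a termination, (1 - terminated)
    # zeroes the bootstrap term at every time limit and biases the values.
    return reward + discount * (1.0 - terminated) * next_q
```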

> I finished my humanoid runs - they take a decent amount of time, around 15 hours each, so 60 hours total for 4 seeds. I'll plot them.
>
> Then, I can pull the updated code and rerun the humanoid task. If you have spare time, could you run the remaining task (finger turn hard)?

Yeah I'll run it after my current runs finish. I'm doing 2M steps over 3 seeds, which takes about 5 hours per run.

edwhu commented 1 month ago

[figure: tdmpc2_humanoid_stand]

Here are the humanoid-stand results with the older code. It also seems to lag behind by ~100. I'll run the updated version.

edwhu commented 1 month ago

And here's the updated plotting code that accounts for multiple seeds. https://gist.github.com/edwhu/710d5202435831c5c761ea53d63a952e
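The gist essentially plots the mean with a std band across seeds; a rough sketch of the idea (array shapes and names here are assumptions, the gist is the reference):

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_mean_std(steps, returns_per_seed, label):
    # returns_per_seed: array of shape (num_seeds, num_points), aligned with steps.
    returns = np.asarray(returns_per_seed)
    mean, std = returns.mean(axis=0), returns.std(axis=0)
    plt.plot(steps, mean, label=label)
    plt.fill_between(steps, mean - std, mean + std, alpha=0.2)
    plt.xlabel('Environment steps')
    plt.ylabel('Episode return')
    plt.legend()
```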

ShaneFlandermeyer commented 1 month ago

So my laptop lost power and died during the walker run, but here are some good results!

Finger turn: [figure: finger_turn_hard]

Cheetah run: [figure: cheetah_run]

ShaneFlandermeyer commented 1 month ago

Any updates on the new humanoid runs? I'm having trouble with walker-run: the reward curve matches the original implementation, but I'm getting an invalid physics state error due to either a NaN or inf action. Not sure why this error would only happen in this environment if there were a hidden bug in the Jax implementation.

edwhu commented 1 month ago

2/4 seeds are done, at 17 hours per seed (2 hours slower than before, maybe due to the float32 dtype). The performance of those 2 seeds is higher than with the old code, so that's good.

For the physics error, you could check with Nicklas if he ran into a similar issue with the original codebase - it's possible that if the policy gets really good at running, it can go too fast and violate some physics constraint.

ShaneFlandermeyer commented 1 month ago

He includes a nan_to_num() call in the planner function that I initially didn't think was necessary. I'll add that back and try again. 2M steps of walker-run only takes 4 hours for me, so I'll post an update soonish.
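For reference, the guard is just something along these lines on the planner's output action (names are mine, and the [-1, 1] bounds are an assumption about the action space):

```python
import jax.numpy as jnp

def sanitize_action(action, low=-1.0, high=1.0):
    # Replace NaN/inf entries before the action reaches the simulator,
    # then keep it inside the environment's action bounds.
    action = jnp.nan_to_num(action)
    return jnp.clip(action, low, high)
```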

ShaneFlandermeyer commented 1 month ago

Yep, the issue persists after that call. I am also inclined to blame the simulator unless Nicklas has additional insight.

Let me check if the NaNs are coming from the gradient computation or the simulator.

ShaneFlandermeyer commented 1 month ago

Both hypotheses are correct: the NaNs occur in the gradients when the simulator observations are large. Fixed by zeroing out NaN gradients.
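In optax terms, that's just chaining a NaN-zeroing transform in front of the optimizer; a minimal sketch (the Adam settings here are placeholders, not our exact config):

```python
import optax

# Replace any NaN gradient entries with 0 before the optimizer update.
optimizer = optax.chain(
    optax.zero_nans(),
    optax.adam(learning_rate=3e-4),
)
```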

ShaneFlandermeyer commented 1 month ago

I think we can safely say this PR is working properly after the latest push. After my last few runs finish up, I would like to add these results to the readme before merging.

One seed of walker-run (with the other 2 following a similar trajectory so far): [figure: walker_run]

edwhu commented 1 month ago

Let's wait for the humanoid runs to finish before merging. We can never be too careful with RL.

In the meantime, we can record the current performance profile of each env, like FPS and returns, so we can track any regressions in future updates. It would also be good to check now how much overhead the NaN checks and the float32 dtype introduce.
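Even a quick timing loop would do for the FPS number; a rough sketch (names are illustrative):

```python
import time

def measure_fps(env, policy_fn, num_steps=1000):
    # Rough throughput check: environment/agent steps per second.
    obs, _ = env.reset()
    start = time.perf_counter()
    for _ in range(num_steps):
        obs, reward, terminated, truncated, _ = env.step(policy_fn(obs))
        if terminated or truncated:
            obs, _ = env.reset()
    return num_steps / (time.perf_counter() - start)
```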

Even better would be some automated test, but that can be a future feature.

edwhu commented 1 month ago

[figure: humanoid_stand]

Here are the Humanoid Stand results with the new code. It's great that it seems to outperform the original implementation.

ShaneFlandermeyer commented 1 month ago

Looks excellent! Can you re-make the humanoid plot with the formatting below for the readme?

https://gist.github.com/ShaneFlandermeyer/d37ba661d06e7e04cb25a195623ac007

edwhu commented 1 month ago

Yup, will do it after work.

> Both hypotheses are correct: the NaNs occur in the gradients when the simulator observations are large. Fixed by zeroing out NaN gradients.

Hmm, I thought TDMPC2 somewhat accounts for this already through SimNorm.

ShaneFlandermeyer commented 1 month ago

> Hmm, I thought TDMPC2 somewhat accounts for this already through SimNorm.

I looked into it deeper yesterday. In my experiments, only the encoder gradients ever went NaN (i.e., before the SimNorm). I think we could get away with using zero_nans() in just the encoder optimizer. I'll do a quick run to verify.

ShaneFlandermeyer commented 1 month ago

Nope; it appears all of the zero_nans() calls are necessary, unfortunately.

edwhu commented 1 month ago

[figure: humanoid_stand_results]

Here's the humanoid stand plot. Could we add the task title to the chart? So it looks like this: plt.title('Humanoid Stand', fontsize=12, fontweight='bold')


ShaneFlandermeyer commented 3 weeks ago

Updated readme with new figures! Going to complete the merge now!