ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Nice work! Reproducing original results? #1

Closed nicklashansen closed 2 months ago

nicklashansen commented 6 months ago

Hi Shane,

Great work on this. I'll definitely have a closer look soon. Do you have any preliminary benchmark results that compare the jax and pytorch versions? Both in terms of env steps and wall-time. Super curious to see how they compare!

ShaneFlandermeyer commented 6 months ago

Benchmarks and improved logging are next up on my to-do list. Unfortunately, I have a conference deadline coming up and will have to grind on that for a couple weeks instead.

For now, I'll be a bit more qualitative: I've tested and verified that both implementations perform similarly in the HalfCheetah, Walker, and Humanoid mujoco environments. In these envs, the Jax agent's throughput is 80-90 FPS, while the Torch implementation sits at 15-20.

nicklashansen commented 6 months ago

Awesome, thanks for the update and best of luck with your deadline!

edwhu commented 4 months ago

@ShaneFlandermeyer did you ever get a chance to reproduce more results? Happy to help out here.

ShaneFlandermeyer commented 4 months ago

@edwhu Unfortunately I have not. These last couple months have been really busy...

I would greatly appreciate help with this if you're up for it! I recommend doing the comparison with the develop branch, as there are some small changes I haven't merged back to main. Let me know if there's anything I can help with!

edwhu commented 4 months ago

I'll take a crack at it.

edwhu commented 4 months ago

Would be nice to keep a globalized list of feature requests somewhere - it seems like we would like:

ShaneFlandermeyer commented 4 months ago

Yeah, that just about covers my ideas for a development road map. Logging support and checkpointing are the best places to start imo. I'll add a "contributing" section to the readme with things that can be improved!

nicklashansen commented 4 months ago

I don't think I can contribute much code-wise (I don't speak jax) but happy to help out with more conceptual things (if any)!

edwhu commented 4 months ago

Are there any good prototypical or debugging tasks you can suggest for comparing performance between the implementations?

A task that's too easy will let buggy implementations go through.

Really hard tasks that TDMPC2 excels at are good, but I would like one that's still pretty fast to train on.

nicklashansen commented 4 months ago

Yea! I find DMControl tasks to be the most reliable for benchmarking (fairly low variance between seeds). I'd recommend using a handful of tasks from different embodiments so that you don't over-optimize for any single task. Something like (cheetah-run, walker-run, finger-turn-hard, humanoid-stand) might be a good task set for experimentation.

On a different note: I'm curious to know exactly what makes the jax implementation of tdmpc2 so much faster? I would be interested in using any insights here to optimize the pytorch version more.

edwhu commented 4 months ago

My guess is the JIT compilation of the MPPI planning procedure could make a big difference. It would be interesting to profile both codebases to figure out the actual bottlenecks.

This JAX implementation is missing some details, documenting here.

ShaneFlandermeyer commented 4 months ago

> My guess is the JIT compilation of the MPPI planning procedure could make a big difference. It would be interesting to profile both codebases to figure out the actual bottlenecks.

Yeah, there are just lots of for loops for Jax to take advantage of in both training and planning. I imagine you could get a big performance boost in the main repo by making more use of the torch.compile tools.

> • some heuristics missing like: planning iterations based on action space, adaptive discount factor based on episode length

I opted to leave these out and let the user define them if they want. Happy to add them in explicitly though!
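For anyone who wants to define these themselves, here is a minimal sketch of what two such heuristics could look like. The function names, the action-dimension threshold, the iteration bonus, and the discount constants are all illustrative assumptions, not values confirmed by either repository.

```python
def planning_iterations(base_iterations: int, action_dim: int,
                        large_action_dim: int = 20,
                        extra_iterations: int = 2) -> int:
    """Give the planner extra iterations when the action space is large.

    The threshold (20) and bonus (2) are assumed defaults for illustration.
    """
    if action_dim >= large_action_dim:
        return base_iterations + extra_iterations
    return base_iterations


def episode_discount(episode_length: int, denom: float = 5.0,
                     discount_min: float = 0.95,
                     discount_max: float = 0.995) -> float:
    """Pick a discount factor that grows with episode length, then clip it.

    denom, discount_min and discount_max are assumed values for illustration.
    """
    if episode_length <= 0:
        raise ValueError("episode_length must be positive")
    frac = episode_length / denom
    return min(max((frac - 1.0) / frac, discount_min), discount_max)
```

With these assumed defaults, longer episodes get a discount closer to 1 and short episodes fall back to the lower bound, so the effective planning horizon scales with the task.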

edwhu commented 4 months ago

> Benchmarks and improved logging are next up on my to-do list. Unfortunately, I have a conference deadline coming up and will have to grind on that for a couple weeks instead.
>
> For now, I'll be a bit more qualitative: I've tested and verified that both implementations perform similarly in the HalfCheetah, Walker, and Humanoid mujoco environments. In these envs, the Jax agent's throughput is 80-90 FPS, while the Torch implementation sits at 15-20.

What GPU are you using? I'm getting around 50 FPS for HalfCheetahV4 on the Nvidia 3090.

ShaneFlandermeyer commented 4 months ago

Are you using the 4 envs/UTD=0.5 setting in the config? If so, the number of iterations per second reported by the progress bar is not the actual FPS, as each iteration corresponds to num_env environment steps. The bar output is only the FPS for 1 env (50 iter/s is around 200 FPS for 4 envs).

The numbers were from a laptop 4090, which I assume is similar to your 3090. If that's the true FPS, I'll dig into it tomorrow and see if I can identify possible performance regressions.
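The progress-bar arithmetic above can be sketched as a one-line conversion. The function name is hypothetical; `num_envs` stands in for the `num_env` setting mentioned in the comment.

```python
def env_steps_per_second(iters_per_second: float, num_envs: int) -> float:
    """Convert progress-bar iterations/s into total environment steps/s.

    Each iteration steps every parallel env once, so the total scales
    linearly with the number of envs.
    """
    if num_envs < 1:
        raise ValueError("num_envs must be at least 1")
    return iters_per_second * num_envs


print(env_steps_per_second(50, 4))  # 200
```

So a bar reading of 50 iter/s with 4 envs corresponds to roughly 200 environment steps per second, which is the number to compare against a single-env FPS figure.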

edwhu commented 4 months ago

I think you're right. I was wondering why the tqdm was lagging behind the actual global steps.

edwhu commented 4 months ago

A very preliminary comparison between the jax and original TDMPC2 codebase on DMC Humanoid Stand for 1M steps. Curves look good, but this still isn't completely apples-to-apples.

[Figure: screenshot dated 2024-05-31, learning curves for the Jax and original TDMPC2 codebases on DMC Humanoid Stand]

One implementation difference is that the Jax implementation uses an async vector env with 4 independent copies and updates with a UTD of 0.5. I think TDMPC2 has a UTD of 1.0, and I'm not sure whether it runs 4 copies of the environment.

But overall, the trend looks roughly similar. I'm interested in the speed-up: around 2hrs for 1M steps, so 8hrs total to hit the 4M steps in the original paper, although this may change if we increase the UTD from 0.5 to 1.0. @nicklashansen do you have an estimate of the walltime of running TDMPC2 humanoid standup for 1M steps?

ShaneFlandermeyer commented 4 months ago

Correct me if I'm wrong @nicklashansen, but the original paper results are 1 env, UTD=1. I'll run 1M steps of Humanoid Stand with those settings from your most recent PR @edwhu. I imagine that'll give curves more similar to what's expected. Progress bar estimates 2.5 hours for 1M steps in 1 env.

The DMC env maker also applies an action repeat. So there may be a sneaky factor of 2 depending on if you define FPS in terms of simulator steps or agent interactions.

nicklashansen commented 4 months ago

Yes, the original paper uses 1 env and UTD=1. I believe that the PyTorch implementation is at roughly 1M policy steps (2M sim steps) per day with that setup. Here's some wall-time results comparing 1 env vs. 4 envs: https://github.com/nicklashansen/tdmpc2/issues/18#issuecomment-1939759448
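To put the wall-time figures in this thread on a common scale, here is a rough sketch that converts them to steps per second. The helper names are hypothetical, and the action repeat of 2 is taken from the "1M policy steps (2M sim steps)" remark. Whether the 2.5-hour progress-bar estimate counts policy steps or simulator steps is not stated, so the two rates are not necessarily directly comparable.

```python
SECONDS_PER_HOUR = 3600


def steps_per_second(total_steps: int, wall_hours: float) -> float:
    """Average throughput for a run of total_steps taking wall_hours."""
    if wall_hours <= 0:
        raise ValueError("wall_hours must be positive")
    return total_steps / (wall_hours * SECONDS_PER_HOUR)


def simulator_steps(policy_steps: int, action_repeat: int) -> int:
    """Each policy step advances the simulator action_repeat times."""
    return policy_steps * action_repeat


# PyTorch figure quoted above: roughly 1M policy steps per day.
pytorch_rate = steps_per_second(1_000_000, 24)   # about 11.6 steps/s
# Jax progress-bar estimate quoted above: 2.5 hours for 1M steps.
jax_rate = steps_per_second(1_000_000, 2.5)      # about 111.1 steps/s

print(simulator_steps(1_000_000, action_repeat=2))  # 2000000
```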

ShaneFlandermeyer commented 4 months ago

Here's the 1 env UTD = 1 case for humanoid walk. The final reward looks reasonable, but there is some instability in the intermediate steps. Anyone with more DMC experience know if this looks reasonable?

[Figure: humanoid-walk, single env]

edwhu commented 4 months ago

I think this looks reasonable for a single run, although @nicklashansen should confirm. We could also let it run for longer since instability during early training is very normal.

I see there are some raw TDMPC2 curves here, could you plot these on top of what you have now?

ShaneFlandermeyer commented 4 months ago

Overlaid figure:

[Figure: humanoid-walk, single env, compared with the raw TDMPC2 curves]

edwhu commented 4 months ago

That looks pretty good to me, barring the potential difference in logging between codebases for action repeat.

nicklashansen commented 4 months ago

The raw official numbers that we provide are evals, i.e., mean of multiple episodes at fixed intervals. If you're plotting the training episode returns then I would expect to see some variance / performance drops between individual episodes.
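A minimal sketch of that kind of evaluation protocol, for comparison with plotting raw training returns. The names `evaluate` and `should_evaluate`, and the default of 10 episodes, are assumptions for illustration; `run_episode` is a hypothetical callable that plays one evaluation episode and returns its total reward.

```python
from statistics import fmean
from typing import Callable


def evaluate(run_episode: Callable[[], float], num_episodes: int = 10) -> float:
    """Average the return of several evaluation episodes into one data point."""
    if num_episodes < 1:
        raise ValueError("num_episodes must be at least 1")
    returns = [run_episode() for _ in range(num_episodes)]
    return fmean(returns)


def should_evaluate(step: int, eval_interval: int) -> bool:
    """True at fixed intervals of environment steps."""
    return step > 0 and step % eval_interval == 0
```

Because each plotted point is a mean over several episodes, a single bad episode moves the curve much less than it would in a plot of individual training returns.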

ShaneFlandermeyer commented 4 months ago

@nicklashansen Did you run into any mujoco physics errors in your DMControl experiments? We've been able to match the PyTorch implementation's performance in all the envs we've tried except for walker-run, which gets a high reward (>850) before throwing a mujoco error.

See #5 for ongoing discussion on the topic.

ShaneFlandermeyer commented 3 months ago

Here are the final results for several DMControl envs. The Jax implementation does very well!

[Figures: walker-run, humanoid-stand, finger-turn-hard, cheetah-run]

nicklashansen commented 3 months ago

@ShaneFlandermeyer @edwhu This looks great! Somewhat surprised to see that the performance now exceeds the pytorch version, any idea what has changed?

edwhu commented 3 months ago

Might be the REDQ update for the q ensemble.

> On Mon, Jun 10, 2024 at 1:48 PM Nicklas Hansen wrote:
>
> @ShaneFlandermeyer @edwhu This looks great! Somewhat surprised to see that the performance now exceeds the pytorch version, any idea what has changed?

ShaneFlandermeyer commented 3 months ago

@nicklashansen Yep. We tried to match things as closely as possible to the original repo except for in the value estimation and actor loss, where the mean is taken over the full critic ensemble. We have to compute those values anyways due to vmap'ing, so it doesn't add much computation at all.

nicklashansen commented 3 months ago

Interesting, I will have to look into this. A few prior works (e.g. REDQ) show that subsampling the Q-functions generally performs better than taking the mean over the full ensemble, but perhaps this is somewhat setup-specific.
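The two reductions being contrasted here can be sketched in plain Python. This is only an illustration of the idea, not code from either repository: the function names are hypothetical, the subset size of 2 is an assumed default, and `q_values` stands for the per-critic estimates for a single state-action pair.

```python
import random
from statistics import fmean
from typing import Optional, Sequence


def full_ensemble_mean(q_values: Sequence[float]) -> float:
    """Mean over every critic in the ensemble."""
    return fmean(q_values)


def subsampled_mean(q_values: Sequence[float], subset_size: int = 2,
                    rng: Optional[random.Random] = None) -> float:
    """Mean over a random subset of critics (subset_size is an assumption)."""
    rng = rng if rng is not None else random.Random()
    return fmean(rng.sample(list(q_values), subset_size))


def subsampled_min(q_values: Sequence[float], subset_size: int = 2,
                   rng: Optional[random.Random] = None) -> float:
    """Minimum over a random subset of critics, a more pessimistic estimate."""
    rng = rng if rng is not None else random.Random()
    return min(rng.sample(list(q_values), subset_size))
```

The full-ensemble mean is deterministic given the critic outputs, while the subsampled variants inject randomness from which critics are drawn, which is one plausible place for the two implementations to diverge.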