nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Handle visual observation. #9

Closed Jackory closed 9 months ago

Jackory commented 10 months ago

Thanks for this wonderful work. I see that all the experiments use low-level state inputs, but DMC and Meta-World also support visual inputs. Have you tried TD-MPC2 on these tasks? In my opinion, being able to process visual inputs is a critical capability for a world model.

nicklashansen commented 10 months ago

Hi @Jackory, this is a very valid question which I have gotten a few times already, so you're not alone in thinking that.

We did run benchmarks with visual inputs for TD-MPC1, and code for reproducing those experiments is available here: https://github.com/nicklashansen/tdmpc. We have also used TD-MPC1 in several follow-up works in sim and real that use visual inputs. The most recent (at time of writing) is MoDem-V2 https://arxiv.org/abs/2309.14236, which takes 3x RGB cameras of 224x224 pixels as input along with robot proprioceptive state, so we are fairly confident in its ability to handle high-dimensional inputs.
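For a sense of scale, the camera setup mentioned above can be turned into a quick back-of-the-envelope count. This is only arithmetic on the numbers quoted in the comment (3 RGB cameras at 224x224). The bytes-per-value figures are assumptions about storage type, and the proprioceptive state is left out because its dimension is not given here.

```python
# Rough size of one visual observation: 3 RGB cameras at 224x224 pixels.
NUM_CAMERAS = 3
HEIGHT = WIDTH = 224
CHANNELS = 3  # RGB

values_per_camera = CHANNELS * HEIGHT * WIDTH
values_per_step = NUM_CAMERAS * values_per_camera

# Assumed storage types (not stated in the thread):
bytes_uint8 = values_per_step * 1    # raw 8-bit pixels
bytes_float32 = values_per_step * 4  # pixels converted to 32-bit floats

print(values_per_camera)  # 150528
print(values_per_step)    # 451584
print(bytes_float32)      # 1806336
```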

For TD-MPC2, we wanted to mainly focus on the algorithmic properties and felt that visual benchmarking had a relatively lower priority compared to scaling / multi-task learning and algorithmic robustness (no hyperparameter-tuning). That said, I agree that releasing some visual benchmark results + public code for it would be valuable. We did not do any such benchmarking for TD-MPC2 (even internally), but I don't see any reason why it wouldn't perform comparably or better than TD-MPC1 in that setting as well. I'll definitely look into it!

Jackory commented 10 months ago

Makes sense.

nicklashansen commented 9 months ago

@Jackory FYI I have started working on this. Commit https://github.com/nicklashansen/tdmpc2/commit/bfb19718981608fc8f032f5681d38506240eb948 adds basic support for RGB observations, but the current replay buffer is very slow for high-dimensional data. I'm working on a faster replay buffer in the experimental branch and will merge it into main when it's ready.
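The thread does not say how the faster replay buffer works. Purely as a general illustration, one common way to store image observations compactly is to preallocate a fixed-size array of uint8 frames and convert to float only when sampling. The sketch below assumes NumPy; `FrameRingBuffer`, the 64x64 frame size, and the capacity are hypothetical and are not taken from the tdmpc2 repository.

```python
import numpy as np


class FrameRingBuffer:
    """Hypothetical fixed-capacity buffer storing RGB frames as uint8."""

    def __init__(self, capacity, height=64, width=64, channels=3):
        self.capacity = capacity
        # Preallocate once so adding a frame never reallocates memory.
        self.frames = np.zeros((capacity, channels, height, width), dtype=np.uint8)
        self.ptr = 0   # next write position
        self.size = 0  # number of valid frames stored

    def add(self, frame):
        # Overwrite the oldest frame once the buffer is full.
        self.frames[self.ptr] = frame
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, rng):
        # Sample uniformly with replacement; scale to [0, 1] floats on the way out.
        idx = rng.integers(0, self.size, size=batch_size)
        return self.frames[idx].astype(np.float32) / 255.0


rng = np.random.default_rng(0)
buf = FrameRingBuffer(capacity=4)
for i in range(6):
    buf.add(np.full((3, 64, 64), i, dtype=np.uint8))
batch = buf.sample(2, rng)
print(buf.size, batch.shape, batch.dtype)
```

Keeping pixels as uint8 in storage uses a quarter of the memory of float32, at the cost of a conversion on every sampled batch.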

nicklashansen commented 9 months ago

@Jackory Commit https://github.com/nicklashansen/tdmpc2/commit/1f6c7771b92edd8d5502f910d5582ebf8ee88675 introduces full support for RGB observations for DMControl tasks, as well as a new, faster replay buffer. My own preliminary benchmarking suggests that performance of visual TD-MPC2 is comparable to that of TD-MPC1 / DrQ-v2 / DreamerV3 on these tasks. I'll run a more thorough benchmark soon. Closing this issue, but feel free to reopen if you have any follow-up questions!