nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Question: Pretrained encoders for image observations? #14

Closed: truncs closed this issue 4 months ago

truncs commented 5 months ago

Hi Nick,

Thanks for adding support for RGB observations to TD-MPC2. I have a few questions regarding image observations:

Thanks

nicklashansen commented 5 months ago

Hi @truncs, these are great questions.

Regarding your questions on data augmentation and pretrained representations, I'd like to share a reference that aims to answer these questions in a broader RL/IL context: On Pre-Training for Visuo-Motor Control: Revisiting a Learning-from-Scratch Baseline. In this paper, we evaluate existing training pipelines for visual RL and find that (1) augmentation is a great regularizer for visual encoders learned from scratch, especially when trained with TD-learning, and (2) representations trained on in-domain data consistently beat frozen pretrained representations on existing visual RL benchmarks. The reasons for this are complex and we go into more detail in the paper (backed by data), but the TL;DR is that (1) existing benchmarks don't really require much out-of-domain knowledge, and (2) given (1), training on in-domain data should yield the best performance unless in-domain data is very limited (e.g. few-shot learning). That said, there are other reasons why one might want to use pretrained representations on existing benchmarks, such as trading some task performance for computational efficiency. I am planning to add support for this, but don't have an ETA at the moment.

To answer your last question: I agree that this makes qualitative evaluation of the model somewhat challenging. This is very much an open question, but a few things come to mind: open-loop control (plans can be visualized with the simulator and/or compared to closed-loop control), t-SNE-like visualizations, or training a decoder post hoc and using its reconstructions as a measure of how much information the latent retains. Surprisingly, I find that a converged TD-MPC model trained on cartpole is able to plan and execute trajectories hundreds of steps into the future without any environment feedback (open-loop), and without losing balance. I don't really know what to do with that information, but it does suggest that the model learns something sensible. Lastly, I want to point out that good reconstruction is not perfectly aligned with control / task performance either and has its own drawbacks; a good reference is Objective Mismatch in Model-based Reinforcement Learning.
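As a rough illustration of the post-hoc decoder idea, the sketch below fits a linear (affine) decoder by least squares from frozen latents back to a target signal and reports held-out error; lower error means more of the target is recoverable from the latent. Everything here is an assumption for illustration: the function name `probe_reconstruction_error`, the use of simulator state as the decoding target, and the random linear maps standing in for trained encoders. A real probe on TD-MPC2 latents would more likely decode pixels or state with a small learned network.

```python
import numpy as np


def probe_reconstruction_error(latents, targets, train_frac=0.8):
    """Fit an affine decoder latents -> targets by least squares on a train
    split and return the mean squared error on the held-out split."""
    n = latents.shape[0]
    n_train = int(train_frac * n)
    # Append a bias column so the decoder is affine rather than purely linear.
    z = np.hstack([latents, np.ones((n, 1))])
    w, *_ = np.linalg.lstsq(z[:n_train], targets[:n_train], rcond=None)
    pred = z[n_train:] @ w
    return float(np.mean((pred - targets[n_train:]) ** 2))


rng = np.random.default_rng(0)
states = rng.normal(size=(1000, 4))  # hypothetical ground-truth simulator state

# Hypothetical stand-in "encoders" (frozen random projections):
enc_full = states @ rng.normal(size=(4, 16))             # retains all 4 state dims
enc_partial = states[:, :2] @ rng.normal(size=(2, 16))   # discards the last 2 dims

print(probe_reconstruction_error(enc_full, states))     # close to 0
print(probe_reconstruction_error(enc_partial, states))  # roughly 0.5
```

The partial encoder's error stays high because the two discarded state dimensions cannot be recovered by any decoder, which is exactly the kind of information loss such a probe is meant to expose.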

truncs commented 5 months ago

Thanks for the quick reply Nick!

Some follow ups -

  1. I tried using an image encoder without augmentation and the system wasn't learning much. I will kick off a run with augmentation and see how that works.
  2. Any plans to support VC-style (variance/covariance) regularization? I have been thinking about it, but it seems like it has to be done carefully: all samples come from the same environment, so a batch can potentially contain the exact same trajectory multiple times, which would make the variance/covariance regularization nonsensical.
  3. I glanced through the paper Objective Mismatch in Model-based Reinforcement Learning, and while I agree with some of its points, my empirical results with DreamerV3/V2 have consistently shown that reward increases as model accuracy increases. Thinking more about it, this makes sense, since a better model serves as a better open-loop simulator for training the actor/critic. I do think the two objectives could be aligned better, though. In the case of DreamerV3, the dynamics model conveniently seems to ignore all the dynamic objects in the scene, yet predicts with decent accuracy the changes in the static scene when an action is taken. The dynamic objects are important in the environment (otherwise the system will collide with them), but they are probably harder to reconstruct than the other pixels. This is also what prompted me to think about using a frozen encoder that wouldn't ignore these objects, thus forcing the dynamics model to predict them in future frames.
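To make the concern in point 2 concrete, here is a minimal numpy sketch of variance and covariance penalties in the style of VICReg, one common formulation of variance/covariance regularization. The function names, `gamma=1.0`, and `eps=1e-4` are illustrative assumptions, and TD-MPC2 does not implement this. The variance term is a hinge on each latent dimension's batch standard deviation; the covariance term penalizes off-diagonal entries of the batch covariance matrix. If a batch is filled with copies of the same state, every dimension's standard deviation collapses to about zero, so the variance term hits its maximum and pushes apart embeddings that ought to be identical, while the covariance term carries no signal at all.

```python
import numpy as np


def variance_loss(z, gamma=1.0, eps=1e-4):
    """Hinge penalty on the per-dimension standard deviation of a batch of
    embeddings z with shape (N, D): mean over D of max(0, gamma - std)."""
    std = np.sqrt(z.var(axis=0, ddof=1) + eps)
    return float(np.mean(np.maximum(0.0, gamma - std)))


def covariance_loss(z):
    """Sum of squared off-diagonal entries of the batch covariance, divided by D."""
    n, d = z.shape
    zc = z - z.mean(axis=0)
    cov = zc.T @ zc / (n - 1)
    off_diag = cov - np.diag(np.diag(cov))
    return float(np.sum(off_diag ** 2) / d)


rng = np.random.default_rng(0)
diverse = rng.normal(size=(256, 32))                      # 256 distinct states
duplicated = np.tile(rng.normal(size=(1, 32)), (256, 1))  # 256 copies of one state

print(variance_loss(diverse), variance_loss(duplicated))      # small vs. about 0.99
print(covariance_loss(diverse), covariance_loss(duplicated))  # nonzero vs. about 0
```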

nicklashansen commented 5 months ago
  1. Yes definitely apply augmentation! This is a good reference: Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels
  2. This would be interesting and something I would like to try, not sure about the specifics though atm. But if you have a large and reasonably diverse offline RL dataset I could see that working well.
  3. Yea! Model accuracy and downstream task performance are definitely correlated, but not perfectly. I think part of the reason why it works well for current simulated benchmarks is that those are visually very simple. If you have a more realistic environment with diverse backgrounds, distracting objects, etc. the conclusion might be different. This paper is a little bit along those lines: https://arxiv.org/abs/2309.00082 (and has experiments on TD-MPC + Dreamer)

return-sleep commented 5 months ago

I'm a bit confused about this data augmentation: do consecutive frames along the time dimension receive the same offset, especially when we train the dynamics model to learn enc(x_{t+1}) = f(a_t, enc(x_t))?

nicklashansen commented 5 months ago

Yes, the same augmentation is applied to frames across the time dimension, but not across the batch dimension. We use the augmentation from https://github.com/facebookresearch/drqv2 without modification and haven't tried any alternatives.
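The scheme described above (one random offset per sequence, shared by every frame in that sequence) can be sketched in numpy as follows. This is an integer-pixel random shift with edge (replicate) padding in the spirit of DrQ-v2's augmentation, not the DrQ-v2 code itself; the `(T, B, C, H, W)` layout, the function name, and `pad=4` are assumptions for illustration. Sharing the offset over time keeps consecutive frames in the same pixel frame of reference, so the dynamics model is not asked to predict a random shift between x_t and x_{t+1}.

```python
import numpy as np


def random_shift_shared_over_time(obs, pad=4, rng=None):
    """Randomly translate image sequences by up to `pad` pixels.

    obs: array of shape (T, B, C, H, W). One (dy, dx) offset is sampled per
    batch element and reused for all T frames of that element, so frames of
    the same sequence are shifted identically while different sequences in
    the batch get independent shifts.
    """
    rng = np.random.default_rng() if rng is None else rng
    _, b, _, h, w = obs.shape
    # Replicate-pad only the two spatial dimensions.
    padded = np.pad(obs, ((0, 0), (0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    # Offset into the padded image; (pad, pad) corresponds to no shift.
    offsets = rng.integers(0, 2 * pad + 1, size=(b, 2))
    out = np.empty_like(obs)
    for i, (dy, dx) in enumerate(offsets):
        out[:, i] = padded[:, i, :, dy:dy + h, dx:dx + w]
    return out, offsets


rng = np.random.default_rng(0)
obs = rng.random((3, 2, 3, 64, 64)).astype(np.float32)  # T=3 frames, B=2 sequences
aug, offsets = random_shift_shared_over_time(obs, rng=rng)
print(aug.shape)  # (3, 2, 3, 64, 64)
```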

Edit: I should add that we use the same augmentation in our TD-MPC1 pixel experiments as well, implemented here: tdmpc/src/algorithm/helper.py#L131. I did not really experiment with different visual encoders for either TD-MPC1 or TD-MPC2, so there's probably something to gain by tuning that, but I think our current default works well enough for most use cases. If you have multiple camera views, you can simply build an encoder for each camera and average the features; we do that in this paper, [MoDem-V2: Visuo-Motor World Models for Real-World Robot Manipulation](https://arxiv.org/abs/2309.14236), with 3 RGB views at 224x224 resolution.
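A minimal sketch of the multi-camera scheme mentioned above: one encoder per view, with the per-view features averaged into a single latent. The toy encoders (flatten, fixed random projection, ReLU) and the names `make_toy_encoder` and `encode_multiview` are hypothetical stand-ins for learned convolutional encoders, and the tiny 32x32 resolution only keeps the example fast.

```python
import numpy as np


def make_toy_encoder(in_shape, latent_dim, seed):
    """Hypothetical stand-in for a per-camera image encoder: flatten, then a
    fixed random linear projection followed by ReLU."""
    rng = np.random.default_rng(seed)
    n_in = int(np.prod(in_shape))
    w = rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, latent_dim))

    def encode(images):  # images: (B, C, H, W) -> (B, latent_dim)
        return np.maximum(0.0, images.reshape(images.shape[0], -1) @ w)

    return encode


def encode_multiview(views, encoders):
    """Encode each camera view with its own encoder and average the features."""
    feats = [enc(v) for enc, v in zip(encoders, views)]
    return np.mean(np.stack(feats, axis=0), axis=0)  # (B, latent_dim)


shape = (3, 32, 32)  # toy resolution; MoDem-V2 uses 224x224 views
encoders = [make_toy_encoder(shape, 64, seed=k) for k in range(3)]
rng = np.random.default_rng(0)
views = [rng.random((8, *shape)) for _ in range(3)]  # 3 cameras, batch of 8
z = encode_multiview(views, encoders)
print(z.shape)  # (8, 64)
```

Averaging rather than concatenating keeps the latent size independent of the number of cameras.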

nicklashansen commented 4 months ago

Closing this issue but feel free to reopen if you have any follow-up questions!

truncs commented 4 months ago

A follow-up to the previous thread (also happy to email if that works better). I trained TD-MPC2 with a frozen DINOv2 encoder (black line in the graph), and it seems to work surprisingly well compared to training from scratch with DrQ-v2 image augmentation (pink). Without image augmentation (orange), the from-scratch encoder does worse than with augmentation. While the tabula rasa encoder has trouble scaling with more data, DINOv2 seems to scale well with more data. My task is very different from the existing robotics benchmarks and is based on Unreal Engine.

So it does seem like, for real-world tasks, the current way of training encoders jointly in TD-MPC2 leads to suboptimal representations and hence lower performance? Of course, this hypothesis needs to be tested on more real-world tasks. Happy to share more.

[Screenshot of training curves: Screenshot_2024-02-21_10-07-32]

nicklashansen commented 4 months ago

I agree that pretrained encoders should intuitively work better in settings that are more realistic or require a higher degree of generalization. I think most of the results I have seen comparing pretrained vs. end-to-end visual encoders have been largely inconclusive or very situational, e.g. pretrained encoders work pretty poorly in DMControl but can be quite competitive in navigation tasks and in manipulation environments where the agent also has access to proprioceptive information. I would expect this to be true for most RL algorithms, not just TD-MPC1/2.