wayveai / mile

PyTorch code for the paper "Model-Based Imitation Learning for Urban Driving".
MIT License
365 stars 35 forks

What happens when we learn the world model and policy separately? #36

Closed: return-sleep closed this issue 1 year ago

return-sleep commented 1 year ago

Thank you for this excellent work. However, a few things confuse me.

  1. MILE jointly learns a model of the world and a policy \hat{a}_t ~ \pi(h_t, s_t), so the predicted action should influence the latent dynamics s_t. But in the released code (models/transition.py), MILE seems to still use the ground-truth action to compute the distribution of the latent state. Did I miss some critical part?
    # Excerpt from the transition model's __init__:
    self.active_inference = False
    if self.active_inference:
        print('ACTIVE INFERENCE!!')
    # Excerpt: method of the same class
    def imagine_step(self, h_t, sample_t, action_t, use_sample=True, policy=None):
        if self.active_inference:
            # Predict action with policy
            action_t = policy(torch.cat([h_t, sample_t], dim=-1))
  2. What would happen if we trained the world model and the policy separately? When training the world model, we would only condition on the ground-truth action and reconstruct the image and segmentation. With a pretrained world model, visual frames could be encoded into latent states, and the policy network could then be trained to map those states to expert actions, as in behavior cloning.
  3. Do these two approaches make a significant difference? Is there mutual reinforcement between behavior learning and model learning? Why learn jointly rather than separately?
  4. The paper evaluates closed-loop performance with two different strategies: "reset state" and "fully recurrent". The latter is easy to understand, but how is the state reset? Is each frame treated as the first frame?
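The switch raised in question 1 can be illustrated with a minimal, pure-Python sketch of an imagination rollout. It assumes scalar toy states and made-up linear dynamics; `ToyTransition`, `dynamics`, `rollout`, and `Policy` are hypothetical names, not MILE's actual transition model, which operates on PyTorch tensors and samples the stochastic state. Only the `active_inference` branch mirrors the excerpt quoted in question 1.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical policy signature: maps (h_t, s_t) to a scalar action.
Policy = Callable[[float, float], float]


@dataclass
class ToyTransition:
    """Hypothetical stand-in for a recurrent transition model (not MILE's code)."""

    active_inference: bool = False

    def dynamics(self, h_t: float, s_t: float, a_t: float) -> Tuple[float, float]:
        # Made-up deterministic update; a real model would sample s_{t+1}.
        h_next = 0.5 * h_t + 0.5 * s_t + a_t
        s_next = 0.9 * h_next
        return h_next, s_next

    def imagine_step(
        self, h_t: float, s_t: float, action_t: float, policy: Optional[Policy] = None
    ) -> Tuple[Tuple[float, float], float]:
        if self.active_inference:
            if policy is None:
                raise ValueError("active_inference=True requires a policy")
            # Replace the ground-truth action with the policy's prediction.
            action_t = policy(h_t, s_t)
        return self.dynamics(h_t, s_t, action_t), action_t


def rollout(
    model: ToyTransition,
    h0: float,
    s0: float,
    gt_actions: Sequence[float],
    policy: Optional[Policy] = None,
) -> Tuple[Tuple[float, float], List[float]]:
    """Imagine len(gt_actions) steps; return the final state and the actions used."""
    h, s = h0, s0
    used: List[float] = []
    for a in gt_actions:
        (h, s), a_used = model.imagine_step(h, s, a, policy)
        used.append(a_used)
    return (h, s), used
```

With `active_inference=False` the imagined trajectory follows the logged expert actions; with `active_inference=True` the policy's own predictions drive the dynamics, and the ground-truth actions are ignored.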
anthonyhu commented 1 year ago

Hello!

  1. The future latent state prediction can use either the ground-truth action or the action predicted by the policy (when self.active_inference = True). In practice, both approaches lead to roughly the same performance.
  2. and 3. We found that jointly training the world model (understanding the world) and the policy (predicting which action to take given the current state) was beneficial. In particular, the inferred state generalised better to unseen towns and weather conditions.
  4. What we call "reset state" means we always recompute the state using the full history of image context [o_1, ..., o_t] (so it's computationally more expensive).
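The difference between the two evaluation strategies can be sketched with a hypothetical scalar recurrent update; `update`, `fully_recurrent`, and `reset_state` are illustrative names, not MILE functions. "Fully recurrent" carries the state forward with one update per new frame, while "reset state" restarts from the initial state at every step and re-encodes the full history [o_1, ..., o_t], so the number of updates grows quadratically with episode length.

```python
from typing import List, Sequence, Tuple


def update(h: float, o: float) -> float:
    """Hypothetical toy recurrent update h_t = f(h_{t-1}, o_t)."""
    return 0.5 * h + o


def fully_recurrent(observations: Sequence[float], h0: float = 0.0) -> Tuple[List[float], int]:
    """Carry the state forward: one update per new frame."""
    states: List[float] = []
    h, n_updates = h0, 0
    for o in observations:
        h = update(h, o)
        n_updates += 1
        states.append(h)
    return states, n_updates


def reset_state(observations: Sequence[float], h0: float = 0.0) -> Tuple[List[float], int]:
    """At every step t, restart from h0 and re-encode the full history o_1..o_t."""
    states: List[float] = []
    n_updates = 0
    for t in range(1, len(observations) + 1):
        h = h0
        for o in observations[:t]:
            h = update(h, o)
            n_updates += 1
        states.append(h)
    return states, n_updates
```

In this deterministic toy both strategies produce identical states, so the only visible difference is cost: T updates versus T(T+1)/2. In a model that samples its latent state, re-encoding the history would resample the whole trajectory at every step, so the two could diverge; this sketch does not model that.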
return-sleep commented 1 year ago

Thanks for your response!

anthonyhu commented 1 year ago

No problem!