ARISE-Initiative / robomimic

robomimic: A Modular Framework for Robot Learning from Demonstration

BC Transformer is not autoregressive during inference #109

Closed mhyatt000 closed 1 year ago

mhyatt000 commented 1 year ago

Hello, this is my first issue so apologies if I missed something.

I have been training with algo/bc.BC_Transformer and comparing its performance with BC_RNN. However, while the transformer is trained in parallel over full sequences, inference is autoregressive, since the agent must interact with the RoboSuite environment one step at a time, i.e.:

```
a0 = f(o0)
a0, a1 = f(o0, o1)
...
```

Problem

As it stands, the bc.py algorithms have no mechanism to use multiple observations in the get_action method. While this is not a problem for BC_RNN (which maintains a hidden state that is reset after each episode), BC_Transformer has no such hidden state and needs all of the past observations in order to accurately predict the next action in the sequence. Otherwise the model is not performing sequence modeling.
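To illustrate the difference, here are two simplified, hypothetical policy classes (for illustration only, not robomimic's actual implementations):

```python
import torch.nn as nn

class RNNPolicy(nn.Module):
    """The recurrent hidden state summarizes all past observations,
    so each get_action call only needs the current one."""
    def __init__(self, obs_dim=3, act_dim=2, hidden_dim=8):
        super().__init__()
        self.rnn = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, act_dim)
        self.hidden = None  # carried across get_action calls, reset per episode

    def get_action(self, obs):  # obs: (B, 1, obs_dim) -- one timestep
        out, self.hidden = self.rnn(obs, self.hidden)
        return self.head(out[:, -1, :])

class TransformerPolicy(nn.Module):
    """No recurrent state: the model only attends over the observations
    it is handed, so the caller must supply the full history window."""
    def __init__(self, obs_dim=3, act_dim=2, d_model=8):
        super().__init__()
        self.embed = nn.Linear(obs_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=2, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.head = nn.Linear(d_model, act_dim)

    def get_action(self, obs):  # obs: (B, T, obs_dim) -- T past timesteps
        return self.head(self.encoder(self.embed(obs))[:, -1, :])
```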

Further, in models/obs_nets.MIMO_Transformer, when the model sees only one observation but expects cfg.context_length of them, that single observation is silently broadcast across all timesteps inside MIMO_Transformer.input_embedding().

In my experiments, a tensor of shape (1, 512) added to a tensor of shape (10, 512) broadcasts to shape (10, 512).
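A quick PyTorch snippet reproducing this silent broadcast (10 and 512 stand in for cfg.context_length and the embedding size):

```python
import torch

obs_embed = torch.randn(1, 512)   # embedding of the single observation
pos_embed = torch.randn(10, 512)  # embeddings for all context_length timesteps
out = obs_embed + pos_embed       # broadcasts silently, no error raised
print(out.shape)                  # torch.Size([10, 512]) -- every timestep
                                  # now holds a copy of the same observation
```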

Solution

I have fixed this problem by retaining observations in a buffer and selecting the last n observations for the rollout. The observations are padded in the event that the agent has not yet experienced cfg.context_length observations. Rather than selecting the last observation in the sequence, the correct observation (relative to padding) is selected. This is sometimes the last observation. At the end of the rollout episode the buffer is cleared.

```python
from functools import reduce  # this import belongs at module level

def get_action(self, obs_dict, goal_dict=None):
    """
    Get policy action outputs.

    Args:
        obs_dict (dict): current observation
        goal_dict (dict): (optional) goal

    Returns:
        action (torch.Tensor): action tensor
    """
    assert not self.nets.training

    # add a time dimension to the current observation and buffer it
    obs_dict = TensorUtils.to_sequence(obs_dict)
    self.rollhist.append(obs_dict)

    # zero-pad at the end if fewer than context_length observations have
    # been seen, then keep only the most recent context_length entries
    zeroes = TensorUtils.single_apply(obs_dict, lambda x: torch.zeros(x.shape, device=x.device))
    n = self.context_length - len(self.rollhist)
    in_dict = (self.rollhist + [zeroes] * max(0, n))[-self.context_length:]

    # concatenate the buffered observations along the time dimension
    timecat = lambda x, y: torch.cat((x, y), dim=1)
    in_dict = reduce(lambda a, b: TensorUtils.both_apply(a, b, timecat), in_dict[1:], in_dict[0])

    # select the prediction at the last real (non-padded) timestep, rather
    # than always taking index -1 as the previous version did:
    # return self.nets["policy"](obs_dict, actions=None, goal_dict=goal_dict)[:, -1, :]
    return self.nets["policy"](in_dict, actions=None, goal_dict=goal_dict)[:, min(self.context_length, len(self.rollhist)) - 1, :]

def reset(self):
    """prepare transformer for autoregressive inference in rollout"""
    self.rollhist = []
```
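For clarity, the intended usage during a rollout looks roughly like this (policy, env, and horizon here are placeholders, not robomimic's rollout harness):

```python
# hypothetical rollout loop illustrating when the buffer is used and cleared
policy.reset()  # clears self.rollhist before the episode starts
obs = env.reset()
for _ in range(horizon):
    action = policy.get_action(obs)  # buffers obs, pads, queries the transformer
    obs, reward, done, info = env.step(action)
    if done:
        break
```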

If you like these changes, please let me know how I can incorporate them. I am planning to use RoboMimic in the future and would be interested to speak with one of the maintainers about contributing further improvements.

amandlek commented 1 year ago

Thank you for your clear explanation of the problem you are encountering.

We actually do something very similar to your proposed solution in the FrameStackWrapper, which is used at test time during BC-Transformer rollouts.
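For anyone unfamiliar, the idea is roughly the following. This is a minimal sketch assuming a gym-style environment with dict observations; the class name and padding strategy are illustrative, not the actual FrameStackWrapper implementation:

```python
from collections import deque

import numpy as np

class FrameStackSketch:
    """Keep the last num_frames observations; pad the start of an
    episode by repeating the first observation."""

    def __init__(self, env, num_frames):
        self.env = env
        self.num_frames = num_frames
        self.frames = deque(maxlen=num_frames)

    def reset(self):
        obs = self.env.reset()
        # pad by repeating the initial observation
        for _ in range(self.num_frames):
            self.frames.append(obs)
        return self._stacked()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._stacked(), reward, done, info

    def _stacked(self):
        # stack each key along a new leading time dimension,
        # preserving dict-style observations
        return {k: np.stack([f[k] for f in self.frames]) for k in self.frames[0]}
```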

Tagging @snasiriany @MBronars for visibility and in case they would like to comment further on this.

mhyatt000 commented 1 year ago

@amandlek Thanks for your reply! I was unaware of FrameStackWrapper.

I was planning to implement the following options myself (for my own experiments):

- training on many data sources simultaneously
- multi-gpu / multi-node training
- hindsight relabeling of goals (and packed hindsight relabeling)

Does RoboMimic already support these things? I was not aware of them when reading the documentation.

amandlek commented 1 year ago

Re: "training on many data sources simultaneously" - we plan to support this in the next version! Re: "multi-gpu / multi-node training" - this is not supported at the moment Re: "hindsight relabeling of goals (and packed hindsight relabeling)" - this is not supported at the moment either