vimalabs / VIMA

Official Algorithm Implementation of ICML'23 Paper "VIMA: General Robot Manipulation with Multimodal Prompts"
MIT License

The choice of the action decoder #11

Closed: SiyuanHuang95 closed this issue 1 year ago

SiyuanHuang95 commented 1 year ago

Hi, I noticed that you used torch.distributions.Distribution after the MLP to get the final output. Could you share some insights about this choice? What's the advantage compared with directly using an MLP and softmax?

Also, for training, should we ignore that head and directly apply an NLL loss to the output of the MLP, or should we apply the NLL to the probabilities of that distribution? If possible, could you share some simple code snippets demonstrating the training usage?

BTW, congrats on the ICML acceptance. Well done!

Best,

yunfanjiang commented 1 year ago

Hi there,

Thank you for the kind words. To answer your questions:

> I noticed that you used torch.distributions.Distribution after the MLP to get the final output. Could you share some insights about this choice? What's the advantage compared with directly using an MLP and softmax?

Theoretically, there is no difference between using a categorical distribution and an MLP + softmax. Personally, I find torch distributions convenient because they implement a uniform interface that works with different strategies for modeling action heads.
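To make the equivalence concrete, here is a minimal sketch (not code from the VIMA repository) assuming a hypothetical MLP head that outputs logits over 8 discrete action bins. It checks that a torch.distributions.Categorical built from those logits gives the same probabilities and log-probabilities as softmax/log_softmax, and shows the extra methods (sample, entropy) the distribution object exposes.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical action head: maps a 16-d feature to logits over 8 action bins.
head = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
features = torch.randn(4, 16)  # batch of 4 dummy feature vectors
logits = head(features)

# Option A: plain MLP + softmax.
probs_softmax = torch.softmax(logits, dim=-1)

# Option B: wrap the same logits in a Categorical distribution.
dist = Categorical(logits=logits)

# Both give the same action probabilities...
assert torch.allclose(dist.probs, probs_softmax, atol=1e-6)

# ...and the same log-probability for any chosen action.
actions = torch.tensor([0, 3, 7, 2])
log_p_manual = torch.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(dist.log_prob(actions), log_p_manual, atol=1e-6)

# The distribution object also offers an interface shared with other
# torch distributions (e.g. Normal, MixtureSameFamily).
sampled = dist.sample()         # stochastic action per sample, shape (4,)
greedy = dist.probs.argmax(-1)  # deterministic (most likely) action, shape (4,)
entropy = dist.entropy()        # per-sample entropy, shape (4,)
```

The practical benefit is that downstream code calling sample() or log_prob() does not need to change if the head is later swapped for a Gaussian or mixture distribution.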

> Also, for training, should we ignore that head and directly apply an NLL loss to the output of the MLP, or should we apply the NLL to the probabilities of that distribution? If possible, could you share some simple code snippets demonstrating the training usage?

Sure. In the discrete case, say dist is a torch.distributions.Categorical instance predicted by the model and label is the discretized action; the loss is computed with torch.nn.functional.cross_entropy. Since it takes unnormalized logits as input, we can just pass dist.logits (reshaped if necessary) into the loss function. For the continuous case with a unimodal Gaussian or a GMM, I'd recommend checking out these snippets: here and here.
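A minimal training sketch of the discrete recipe above, assuming a hypothetical MLP head and random stand-in features and labels (the sizes B, T, NUM_BINS and FEAT are made up; this is not VIMA's training code). One detail worth knowing: Categorical stores its logits already normalized, i.e. as log-probabilities. Passing them to cross_entropy is still correct because log_softmax leaves normalized log-probabilities unchanged, so the loss equals the mean negative log_prob of the labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical sizes: batch, time steps, action bins, feature dim.
B, T, NUM_BINS, FEAT = 2, 5, 10, 16

head = nn.Sequential(nn.Linear(FEAT, 64), nn.ReLU(), nn.Linear(64, NUM_BINS))
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

features = torch.randn(B, T, FEAT)            # stand-in for policy features
labels = torch.randint(0, NUM_BINS, (B, T))   # discretized action labels

for step in range(3):
    dist = Categorical(logits=head(features))  # batch_shape (B, T)
    # cross_entropy expects (N, C) logits and (N,) integer targets,
    # so flatten the batch and time dimensions first.
    loss = F.cross_entropy(dist.logits.reshape(-1, NUM_BINS), labels.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Equivalent formulation through the distribution interface:
# the mean negative log-likelihood of the labels.
dist = Categorical(logits=head(features))
ce = F.cross_entropy(dist.logits.reshape(-1, NUM_BINS), labels.reshape(-1))
nll = -dist.log_prob(labels).mean()
assert torch.allclose(ce, nll, atol=1e-5)
```

With a single action per sample (logits of shape (B, NUM_BINS)), no reshape is needed.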

SiyuanHuang95 commented 1 year ago

Many thanks for your reply, @yunfanjiang, and for the informative hints!

  1. MLP + softmax case: Okay, got it. BTW, I noticed that many works train the policy network with an MSE loss, turning training into a regression problem. Have you run any experiments comparing the two?

  2. Okay, thanks. But I noticed that in your work you chose to use discretized actions. So what would be the main difference between the two?

yunfanjiang commented 1 year ago

> Many thanks for your reply, @yunfanjiang, and for the informative hints!
>
>   1. MLP + softmax case: Okay, got it. BTW, I noticed that many works train the policy network with an MSE loss, turning training into a regression problem. Have you run any experiments comparing the two?
>   2. Okay, thanks. But I noticed that in your work you chose to use discretized actions. So what would be the main difference between the two?

Thanks for the follow-up. To answer them:

  1. I assume you were referring to methods with continuous actions. In those cases we can certainly use a regression loss. However, since a GMM is more expressive and better handles distributional multimodality (which is the case for our benchmark, where multiple solutions exist for a single task), we only experimented with GMMs for the continuous-action case.
  2. In our case, we didn't observe a significant difference empirically, so we opted for the simpler choice (discretized actions).
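To illustrate the multimodality point, the following sketch (toy settings chosen for illustration, not VIMA's implementation) fits a hypothetical two-component GMM head, built from torch.distributions.MixtureSameFamily, with a negative log-likelihood loss on toy data where the same input maps to an action of either -1 or +1. A linear regressor trained with MSE on the same data converges toward the average of the two modes, an action that matches neither solution.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

torch.manual_seed(0)

K, ACT_DIM, FEAT = 2, 1, 4  # hypothetical: mixture components, action dim, feature dim


class GMMHead(nn.Module):
    """Hypothetical head mapping features to a K-component diagonal-Gaussian mixture."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(FEAT, K + 2 * K * ACT_DIM)

    def forward(self, x):
        out = self.net(x)
        mix_logits = out[..., :K]
        means = out[..., K:K + K * ACT_DIM].reshape(*x.shape[:-1], K, ACT_DIM)
        log_std = out[..., K + K * ACT_DIM:].reshape(*x.shape[:-1], K, ACT_DIM)
        # Independent(..., 1) treats the action dimensions as one event.
        comp = Independent(Normal(means, log_std.clamp(-5, 2).exp()), 1)
        return MixtureSameFamily(Categorical(logits=mix_logits), comp)


# Toy bimodal data: for the same input, the correct action is either -1 or +1.
N = 512
x = torch.ones(N, FEAT)
y = (torch.rand(N, 1) < 0.5).float() * 2 - 1 + 0.05 * torch.randn(N, 1)

gmm = GMMHead()
mlp = nn.Linear(FEAT, ACT_DIM)  # regression baseline
opt = torch.optim.Adam(list(gmm.parameters()) + list(mlp.parameters()), lr=5e-2)

for _ in range(500):
    nll = -gmm(x).log_prob(y).mean()  # GMM: maximize likelihood of the actions
    mse = ((mlp(x) - y) ** 2).mean()  # regression: minimize squared error
    opt.zero_grad()
    (nll + mse).backward()
    opt.step()

# MSE regression collapses toward the mean of the two modes (close to 0).
print("MSE prediction:", mlp(x[:1]).item())
# Component means: ideally near -1 and +1; exact values depend on initialization.
print("GMM means:", gmm(x[:1]).component_distribution.base_dist.loc.detach().flatten().tolist())
```

At inference time, sampling from the GMM (gmm(x).sample()) returns actions near one of the modes, whereas the regressor always outputs the in-between average.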

Hope this helps.