zichuan-liu / NA2Q

[ICML'23] Official PyTorch Implementation of NA2Q, and a comprehensive benchmark based on pymarl
https://proceedings.mlr.press/v202/liu23be.html

Question About the VAE Loss Function Setup in NA^2Q #1

Open · guestreturn opened 2 months ago

guestreturn commented 2 months ago

Hi! I am very interested in the work presented in NA^2Q. However, I am confused about the setup of one of the loss functions. In $\mathcal{L}_{\text{vae}}$, why is the MSE loss computed between $o_i$ and $\tilde{o}_i$? This does not seem to match the description in reference [1]. If the loss is set up as described in your paper, the generated mask $\mathcal{M}$ would mask out entries of the observation $o$ based solely on their numerical values.

Is that correct?

[1] Self-Supervised Discovering of Interpretable Features for Reinforcement Learning

guestreturn commented 2 months ago

Hi! I hope you can see this issue. I have discussed this problem with my colleagues, but we have not been able to reach a conclusion. We are unsure what motivates this choice of loss function.

The loss function in Equation 6 implies that the output mask tends to retain dimensions with larger absolute values and remove those with smaller absolute values. What is the purpose of this? Our initial impression is that it might be a mistake.
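A minimal worked version of this argument, under assumptions that the thread does not confirm: the reconstruction is $\tilde{o} = m \odot o$ (elementwise masking), each mask entry $m_j \in [0, 1]$ can be chosen independently, and a hypothetical weight $\lambda > 0$ multiplies the L1 term. The loss then separates into one convex term per dimension $j$ (for $o_j \neq 0$):

$$
\ell_j(m_j) = (o_j - m_j o_j)^2 + \lambda m_j = o_j^2 (1 - m_j)^2 + \lambda m_j, \qquad m_j \in [0, 1]
$$

$$
\ell_j'(m_j) = -2 o_j^2 (1 - m_j) + \lambda = 0 \;\Rightarrow\; m_j^\ast = \max\!\left(0,\; 1 - \frac{\lambda}{2 o_j^2}\right)
$$

$$
m_j^\ast > 0 \iff o_j^2 > \frac{\lambda}{2} \iff |o_j| > \sqrt{\lambda / 2}
$$

In this simplified setting, with an unconstrained mask and no other loss term, the optimal mask depends only on $|o_j|$, which is the behaviour described above.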

We sincerely hope you can respond to us promptly.

zichuan-liu commented 1 month ago

Hi, sorry for the slow reply; I'm not in the habit of checking GitHub issues.

The main idea is to retain the relevant observation information (the $o - \tilde{o}$ term) needed for effective prediction (the prediction information comes from $Z \to M$), while keeping the observation representation compact (minimizing the masked region with an L1 penalty). This is similar to the information bottleneck principle: it identifies which parts of the observation make a prediction task effective.

Thus the trade-off loss is $\mathcal{L} = \lVert o - \tilde{o} \rVert_2^2 + \lVert m \rVert_1$, where $o$ and $m$ are normalized to $(0, 1)$.
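A minimal sketch of this trade-off loss for a single observation vector, in plain Python. It assumes, since the thread does not say, that $\tilde{o} = m \odot o$ (elementwise masking), that both terms are averaged over the observation dimensions, and that `lam` is a hypothetical weight on the L1 term (`lam = 1.0` matches the formula as written). `trade_off_loss` is an illustrative name, not a function from the NA2Q code.

```python
from typing import Sequence


def trade_off_loss(o: Sequence[float], m: Sequence[float], lam: float = 1.0) -> float:
    """Hypothetical sketch of L = MSE(o, o_tilde) + lam * |m| for one observation.

    Assumptions (not confirmed by the thread): o_tilde = m * o elementwise,
    both terms are averaged over the d observation dimensions, and lam
    weights the L1 mask term (lam = 1.0 matches the formula as written).
    """
    if len(o) != len(m):
        raise ValueError("o and m must have the same length")
    d = len(o)
    # Reconstruction term: mean squared error between o and the masked o_tilde.
    mse = sum((oj - mj * oj) ** 2 for oj, mj in zip(o, m)) / d
    # Compactness term: mean absolute mask value (L1 norm divided by d).
    l1 = sum(abs(mj) for mj in m) / d
    return mse + lam * l1


# Usage: observation and masks normalized to (0, 1), as in the reply above.
o = [0.9, 0.1]
print(trade_off_loss(o, [1.0, 1.0]))  # keep both dimensions
print(trade_off_loss(o, [1.0, 0.0]))  # drop the small-valued dimension
print(trade_off_loss(o, [0.0, 1.0]))  # drop the large-valued dimension
```

Under these assumptions, with $o = (0.9, 0.1)$, keeping both dimensions costs 1.0, dropping the small-valued one costs about 0.505, and dropping the large-valued one costs about 0.905. This toy case illustrates the effect raised in the question: with no other term in the loss, the cheaper mask is the one that removes the small-magnitude entry. Whether this carries over to NA2Q depends on how $\tilde{o}$ and $M$ are actually computed there.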

Regarding "solely based on their numerical values", could you explain it more clearly?