facebookresearch / hanabi_SAD

Simplified Action Decoder for Deep Multi-Agent Reinforcement Learning

Issues with using auxiliary task in the setup #17

Closed: akileshbadrinaaraayanan closed this issue 3 years ago

akileshbadrinaaraayanan commented 3 years ago

Hi @hengyuan-hu,

I have a few questions regarding the usage of the auxiliary task.

RuntimeError: Error(s) in loading state_dict for R2D2Net:
        size mismatch for net.0.weight: copying a param with shape torch.Size([512, 838]) from checkpoint, the shape in current model is torch.Size([512, 783]).

However, this error goes away when I set --sad 1 in addition to --pred_weight 0.25. Why does enabling the auxiliary task also imply enabling SAD? Can they not be used independently? This also matches what I see in https://github.com/facebookresearch/hanabi_SAD/blob/502d6a7a52028511704c944dffe1945194e10c3a/pyhanabi/tools/eval_model.py#L35
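The two widths in the error above suggest a quick way to tell which input format a checkpoint expects before loading it. This is a minimal sketch, assuming the checkpoint holds the R2D2Net state_dict with the first linear layer stored under net.0.weight (as named in the error message), and taking 838 (SAD input) and 783 (non-SAD input) from that message for this particular setup only. `infer_sad_flag` is a hypothetical helper, not part of the repo.

```python
# Input widths copied from the size-mismatch error in this issue. They are
# assumptions for this specific setup and will differ for other configurations.
SAD_INPUT_DIM = 838
NON_SAD_INPUT_DIM = 783


def infer_sad_flag(state_dict, key="net.0.weight"):
    """Return 1 if the checkpoint expects SAD input, 0 otherwise (hypothetical helper).

    `state_dict` maps parameter names to tensors (anything with a `.shape`).
    A Linear layer's weight has shape (out_features, in_features), so the
    second dimension reveals which input encoding the model was trained on.
    """
    if key not in state_dict:
        raise KeyError(f"{key} not found in checkpoint")
    _, input_dim = tuple(state_dict[key].shape)
    if input_dim == SAD_INPUT_DIM:
        return 1
    if input_dim == NON_SAD_INPUT_DIM:
        return 0
    raise ValueError(f"unexpected input width {input_dim} for {key}")
```

With a real checkpoint, and assuming it was saved with `torch.save(net.state_dict(), path)`, `torch.load(path, map_location="cpu")` returns the dict to pass in, and the result tells you which value to give --sad.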

warning: pred_1st.weight not loaded
warning: pred_1st.bias not loaded
removing: pred.bias not used
removing: pred.weight not used

By looking at load_weight in utils, I understand that the weights of the prediction network are not loaded but are instead left at the newly created model's random initialization. https://github.com/facebookresearch/hanabi_SAD/blob/502d6a7a52028511704c944dffe1945194e10c3a/pyhanabi/utils.py#L214
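To make the behavior described above concrete, here is a simplified, hypothetical name-matching loader that would produce log lines of the same kind as the ones quoted. It is not the actual load_weight in utils.py; `load_matching_weights` is an illustrative name, and it uses plain dicts so it runs without PyTorch.

```python
def load_matching_weights(model_state, ckpt_state):
    """Merge checkpoint weights into a copy of a model's state dict by name.

    Simplified sketch (not the repo's load_weight):
    - names the model has but the checkpoint lacks keep their fresh
      initialization and print "warning: <name> not loaded";
    - names the checkpoint has but the model lacks are dropped and
      print "removing: <name> not used".
    """
    merged = dict(model_state)  # copy, so the caller's dict is not modified
    for name in model_state:
        if name in ckpt_state:
            merged[name] = ckpt_state[name]
        else:
            print(f"warning: {name} not loaded")
    for name in ckpt_state:
        if name not in model_state:
            print(f"removing: {name} not used")
    return merged
```

Under this kind of matching, a layer renamed from pred to pred_1st is treated as two unrelated parameters: the checkpoint's pred.* entries are discarded and the model's pred_1st.* entries stay randomly initialized.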

Could you explain why this is done and how I can ensure my current agent uses the prediction network correctly?

hengyuan-hu commented 3 years ago

1) When we trained the AUX model, we used SAD, so it expects the SAD version of the input at evaluation time. The evaluation script is hardcoded for simplicity. As the function name suggests, it is kind of legacy now; evaluate_saved_model in eval.py is more up to date.

2) This is due to a mismatch in the module name. When the model was trained, the prediction layer was called pred, but in the new code it is called pred_1st. The simplest solution is to change the name back to pred. By the way, if you only want to evaluate the model, you don't really need the prediction weights.
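As an alternative to renaming the module in the code, the checkpoint's keys can be renamed before loading. This is a minimal sketch, assuming the prediction layer's parameters are stored under keys whose module component is exactly `pred` (as in the `pred.weight` / `pred.bias` log lines above) and that the current code expects `pred_1st`. `rename_module` is a hypothetical helper, not part of the repo.

```python
from collections import OrderedDict


def rename_module(state_dict, old="pred", new="pred_1st"):
    """Return a copy of `state_dict` with module `old` renamed to `new`.

    Only whole dot-separated components are renamed, so "pred.weight"
    becomes "pred_1st.weight" while unrelated names such as
    "predictor.weight" are left untouched. Key order is preserved.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        parts = [new if part == old else part for part in key.split(".")]
        renamed[".".join(parts)] = value
    return renamed
```

With PyTorch, the result of `rename_module(torch.load(path, map_location="cpu"))` could then be passed to `model.load_state_dict(...)`. As noted in the reply, this only matters if the prediction head is actually needed, for example when continuing training with the auxiliary loss.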

3) Sorry, this is a bug. I fixed it in the latest push: https://github.com/facebookresearch/hanabi_SAD/commit/54a8d34f6ab192898121f8d3935339e63f1f4b35. Thanks!

akileshbadrinaaraayanan commented 3 years ago

Great, thanks @hengyuan-hu :) Closing this now; I will re-open it if there are any other questions.