facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

SinusoidalPositionEmbedding length does not match message length #188

Closed: nicofirst1 closed this issue 3 years ago

nicofirst1 commented 3 years ago

Expected Behavior

No errors

Detailed Description

When using TransformerSenderReinforce with SinusoidalPositionEmbedding, training raises a runtime error:

Traceback (most recent call last):
  File "/home/dizzi/Desktop/hidden_egg/egg/zoo/coco_game/main.py", line 474, in <module>
    main()
  File "/home/dizzi/Desktop/hidden_egg/egg/zoo/coco_game/main.py", line 469, in main
    trainer.train(n_epochs=opts.n_epochs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/trainers.py", line 262, in train
    train_loss, train_interaction = self.train_epoch()
  File "/home/dizzi/Desktop/hidden_egg/egg/core/trainers.py", line 207, in train_epoch
    optimized_loss, interaction = self.game(*args, **kwargs)
  File "/home/dizzi/anaconda3/envs/egg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/reinforce_wrappers.py", line 505, in forward
    message_length,
  File "/home/dizzi/anaconda3/envs/egg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/reinforce_wrappers.py", line 607, in forward
    transformed = self.encoder(message, lengths)
  File "/home/dizzi/anaconda3/envs/egg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/transformer.py", line 120, in forward
    message, key_padding_mask=padding_mask, attn_mask=attn_mask
  File "/home/dizzi/anaconda3/envs/egg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/transformer.py", line 188, in forward
    x = self.embed_positions(x)
  File "/home/dizzi/anaconda3/envs/egg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dizzi/Desktop/hidden_egg/egg/core/transformer.py", line 39, in forward
    return x + t
RuntimeError: The size of tensor a (12) must match the size of tensor b (11) at non-singleton dimension 1

Here, 11 is my max_len. The error comes from the additional position concatenated here, which gives x a size of 12 along dimension 1 while t still has size max_len=11.
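To see where 12 and 11 come from, here is a minimal pure-Python sketch of the same pattern. It assumes the position table is precomputed for max_len positions and sliced to the input length before being added (the traceback itself only shows the final x + t). The helpers sinusoidal_table and add_positions are hypothetical, not EGG's API, and the explicit length check stands in for the tensor size error PyTorch raises.

```python
import math


def sinusoidal_table(max_len, model_dim):
    """Precompute one sinusoidal vector per position 0..max_len-1 (hypothetical helper)."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(model_dim):
            # Paired dimensions (2k, 2k+1) share one frequency: sin on even, cos on odd.
            angle = pos / (10000 ** (2 * (i // 2) / model_dim))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def add_positions(x, table):
    """Add positional vectors to x, a list of seq_len vectors of size model_dim."""
    t = table[: len(x)]  # slicing past the end silently returns only len(table) rows
    if len(t) != len(x):
        raise ValueError(
            f"The size of x ({len(x)}) must match the size of t ({len(t)}) "
            "along the sequence dimension"
        )
    return [[a + b for a, b in zip(xv, tv)] for xv, tv in zip(x, t)]


max_len, model_dim = 11, 4
# max_len symbols plus the appended EOS position -> 12 rows
message = [[0.0] * model_dim for _ in range(max_len + 1)]

try:
    add_positions(message, sinusoidal_table(max_len, model_dim))
except ValueError as err:
    print(err)  # reports 12 vs 11, the same mismatch as in the traceback

# Sizing the table for max_len + 1 positions leaves room for the EOS slot.
out = add_positions(message, sinusoidal_table(max_len + 1, model_dim))
print(len(out))  # 12
```

Sizing the table with one extra row is the same idea as the max_len+1 workaround suggested below; the sketch only illustrates the arithmetic, not where EGG should apply it.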

What is the reason for that additional zero at the end of the sequence?

robertodessi commented 3 years ago

Hi @nicofirst1 , thanks for spotting this!

The best hacky workaround for now would be to pass max_len+1 as a parameter here.

Please see loosely related issues #137 and #138

We add a zero because, by convention, 0 is the EOS symbol in EGG. If the model has already generated a zero, it will be detected here; otherwise, for consistency across games and models, we make sure that every message contains an EOS, either in the last position or earlier.
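The convention described above can be sketched in plain Python. append_eos and message_length are hypothetical helper names, not EGG functions, and messages are plain lists rather than tensors. The sketch also shows why the receiver always sees max_len + 1 positions: the padding happens regardless of whether the sender already emitted a 0.

```python
EOS = 0  # by convention, symbol 0 marks end-of-sequence


def append_eos(message):
    """Append an EOS so every message contains one, even if the sender never emitted it."""
    return message + [EOS]


def message_length(message):
    """Effective length: everything up to and including the first EOS."""
    return message.index(EOS) + 1


max_len = 11
no_eos = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]     # max_len symbols, no 0 generated
early_eos = [7, 2, 0, 8, 8, 8, 8, 8, 8, 8, 8]  # sender emitted EOS at index 2

for msg in (no_eos, early_eos):
    padded = append_eos(msg)
    print(len(padded), message_length(padded))
# 12 12
# 12 3
```

In both cases the padded sequence has max_len + 1 = 12 positions, even when the effective length is much shorter.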

P.S.: You are using a custom branch of EGG (there are currently no **kwargs in the master branch). Just to be sure, could you check whether this also happens on the master branch? I am almost positive that it does, but I want to make sure your changes did not modify anything else.

nicofirst1 commented 3 years ago

You are using a custom branch of EGG (currently no **kwargs in the master branch), just to be sure, can you please try if it happens even with the master branch?

I removed the kwargs and still got the same error.

The best hacky solution for you for now would be to pass max_len+1 as a parameter here.

Are you planning to integrate this fix?

robertodessi commented 3 years ago

It will be fixed, but there is no ETA yet. I wouldn't change that line to pass max_len+1, since it would be better to centralize the handling of max_len, for example by adding +1 in util.py when parsing command-line args. This would require checking that it doesn't break anything else. Feel free to work on it if you like; otherwise, we'll fix it as soon as we can.
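One way such centralization could look, sketched under assumptions: a hypothetical parse_opts helper built on the standard argparse module (not EGG's actual util.py code) keeps the user-facing max_len and derives the EOS-padded length once, under a hypothetical attribute name max_len_with_eos.

```python
import argparse


def parse_opts(argv):
    """Hypothetical helper: parse max_len once and reserve one extra slot for EOS."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_len",
        type=int,
        required=True,
        help="maximum number of symbols the sender may produce",
    )
    opts = parser.parse_args(argv)
    # Derive the padded length in a single place so that senders, receivers and
    # position embeddings all agree on the size of the message they handle.
    opts.max_len_with_eos = opts.max_len + 1
    return opts


opts = parse_opts(["--max_len", "11"])
print(opts.max_len, opts.max_len_with_eos)  # 11 12
```

Overwriting max_len itself with max_len + 1 would be the other option; keeping a separate attribute makes it explicit which components need the extra EOS slot, at the cost of having to update each of them.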

robertodessi commented 3 years ago

Fixed in #219