facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

TransformerSenderReinforce max_len parameter #247

Open mitjanikolaus opened 1 year ago

mitjanikolaus commented 1 year ago

The docstring of TransformerSenderReinforce mentions that the max_len parameter includes the EOS token: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L687

However, the transformer can generate a message of max_len symbols (not counting the EOS), and the EOS token is appended afterwards: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L835

So, to my understanding, the max_len parameter does not actually include the EOS token?
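
To make the off-by-one concrete, here is a minimal pure-Python sketch of the behavior described above. It is not EGG's code: sample_message is a hypothetical stand-in for the sender, and it assumes that symbol 0 is reserved for EOS.

```python
import random

EOS = 0  # assumption: symbol 0 is reserved for EOS


def sample_message(max_len, vocab_size, rng):
    """Hypothetical stand-in for a sender with a forced EOS."""
    # The sender emits exactly max_len symbols; any of them may happen to be EOS.
    symbols = [rng.randrange(vocab_size) for _ in range(max_len)]
    # An EOS is then appended unconditionally, so it is not counted in max_len.
    return symbols + [EOS]


rng = random.Random(0)
message = sample_message(max_len=3, vocab_size=10, rng=rng)
print(len(message))  # max_len + 1 = 4
```

Under these assumptions the returned message always has max_len + 1 symbols, which is what the docstring would need to say.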

robertodessi commented 1 year ago

True, the docstrings were not updated after we went with the "force eos" decision. Do you want to open a PR about this as well? :)

robertodessi commented 1 year ago

Btw, there are some known inconsistencies in EGG regarding max_len, see #137 and #138. If you feel like fixing them, contributions are always welcome :)

mitjanikolaus commented 1 year ago

I think these two issues are related, but they require a bit more refactoring: it's not straightforward to change the behavior of find_lengths so that it returns lengths of 0, because these can't be handled by the PyTorch RNN implementations.

robertodessi commented 1 year ago

Can we assume length is always > 0 unless the input is something like an empty tensor? We do check that length is greater than one in EGG, so that should not break when calling find_lengths.

mitjanikolaus commented 1 year ago

I was referring to solution 2 that you proposed in #138. In that case, a message starting with an EOS token would be treated as having a length of 0, which then causes issues when it is read by an RNN.
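
A small pure-Python sketch of the two length conventions under discussion. The helper names are hypothetical and this is not EGG's find_lengths; it also assumes that symbol 0 is EOS.

```python
EOS = 0  # assumption: symbol 0 is reserved for EOS


def length_including_eos(message):
    """Count symbols up to and including the first EOS (full length if none)."""
    for i, symbol in enumerate(message):
        if symbol == EOS:
            return i + 1
    return len(message)


def length_excluding_eos(message):
    """Count symbols strictly before the first EOS (full length if none)."""
    for i, symbol in enumerate(message):
        if symbol == EOS:
            return i
    return len(message)


messages = [[5, 3, 0, 0], [0, 7, 2, 0], [4, 4, 4, 0]]
for m in messages:
    print(m, length_including_eos(m), length_excluding_eos(m))
# [5, 3, 0, 0] 3 2
# [0, 7, 2, 0] 1 0
# [4, 4, 4, 0] 4 3
```

The second message is the problematic case: if the EOS is excluded from the count, its length becomes 0.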

robertodessi commented 1 year ago

Technically, a message starting with EOS should not be accepted. How would you handle it? We pass the input through the receiver RNN regardless of the symbol (whether it's EOS or not) and just ignore everything after EOS when computing the loss. Therefore, I guess the RNN implementations shouldn't fail.

mitjanikolaus commented 1 year ago

To my mind, the problem is this issue in PyTorch.

robertodessi commented 1 year ago

But we'll never have a sequence of length 0. This is because 1/ we enforce max_len to be greater than or equal to 1 https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L266

2/ if the user sets max_len equal to 1 and, though unlikely, the sender generates only the EOS, we would have a message of size 2: the sender-generated EOS and the one automatically appended to each message. This tensor is then given to the receiver RNN (in the case of an RNN receiver), and all the EOS handling is done in the Game instance, see e.g. https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L579-L588

For these reasons I don't think we'll encounter the above error.

mitjanikolaus commented 1 year ago

If we consider solution 2 that you proposed in #138, find_lengths would need to return 0 when a message consists only of the EOS token. Or maybe I misunderstood the proposal?

robertodessi commented 1 year ago

Yes, but I think this was before we appended an EOS to every message, so that proposal is outdated. Do you have anything in mind? :)