Status: Open. Opened by mitjanikolaus 1 year ago.
True, docstrings have not been updated after we went with the "force eos" decision. Do you want to open a PR about this as well? :)
Btw, there are some known inconsistencies in EGG regarding `max_len` (see #137 and #138). If you'd like to fix them, contributions are always welcome :)
I think these two issues are related, but they require a bit more refactoring, as it's not straightforward to change `find_lengths` so that it returns lengths of 0 (these can't be handled by the PyTorch RNN implementations).
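A minimal pure-Python sketch of the two length conventions at stake, assuming (as I understand EGG's convention) that EOS is encoded as `0`. `length_including_eos` counts up to and including the first EOS, which is what I understand `find_lengths` to do; `length_excluding_eos` is a hypothetical variant that counts only the symbols before EOS and can therefore return 0. Neither is EGG code.

```python
EOS = 0  # assumption: EOS is encoded as symbol 0


def length_including_eos(message: list[int]) -> int:
    """Symbols up to and including the first EOS; full length if no EOS."""
    for i, symbol in enumerate(message):
        if symbol == EOS:
            return i + 1
    return len(message)


def length_excluding_eos(message: list[int]) -> int:
    """Hypothetical variant: count only the symbols before the first EOS."""
    for i, symbol in enumerate(message):
        if symbol == EOS:
            return i
    return len(message)


batch = [
    [3, 5, 0, 0],  # EOS at position 2
    [0, 4, 4, 0],  # message starts with EOS
    [1, 2, 3, 4],  # no EOS at all
]

print([length_including_eos(m) for m in batch])  # [3, 1, 4]
print([length_excluding_eos(m) for m in batch])  # [2, 0, 4]
```

The 0 in the second output is the kind of length that, per this discussion, the PyTorch RNN path cannot handle; the inclusive convention never produces it for a non-empty message.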
Can we assume the length is always > 0 unless the input is something like an empty tensor? We do check in EGG that the length is at least 1, so calling `find_lengths` should not break.
I was referring to solution 2 that you proposed in #138. In that case, a message starting with an EOS token would be treated as having a length of 0, which then causes issues when it is read by an RNN.
Technically, a message starting with EOS should not be accepted. How would you handle it? We pass the input through the receiver RNN regardless of the symbol (whether it's EOS or not) and just ignore everything after EOS when computing the loss. Therefore, I guess the RNN implementations shouldn't fail.
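The idea of feeding the whole message through the receiver but ignoring everything after EOS can be sketched as a per-step mask. This is a minimal pure-Python illustration, not EGG's code: `masked_mean` is a hypothetical helper, and the length is assumed to include the EOS position. It also shows why a length of 0 matters if the masked sum is normalised by the length: the average would be undefined.

```python
def masked_mean(per_step_values: list[float], length: int) -> float:
    """Average the values at positions 0..length-1 (up to and including EOS).

    Positions after EOS get a mask of 0 and do not contribute.
    Raises ValueError for length <= 0, where the average is undefined.
    """
    if length <= 0:
        raise ValueError(f"length must be >= 1, got {length}")
    total = 0.0
    for i, value in enumerate(per_step_values):
        not_eosed = 1.0 if i < length else 0.0  # 1 before/at EOS, 0 after
        total += value * not_eosed
    return total / length


# Per-step log-probabilities for a 4-symbol message whose EOS is at position 1
# (length 2 including EOS); the last two steps are ignored.
per_step_log_probs = [-0.5, -1.5, -9.0, -9.0]
print(masked_mean(per_step_log_probs, length=2))  # -1.0
```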
To my mind, the problem is this issue in PyTorch.
But we'll never have a sequence of length 0, because:
1/ we enforce `max_len` to be greater than or equal to 1: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L266
2/ if the user sets `max_len` to 1 and, though unlikely, the sender generates only the EOS, we would have a message of size 2: the sender-generated EOS and the one automatically appended to each message. This tensor is then given to the receiver RNN (in the case of an RNN receiver), and all the EOS handling is done in the Game instance; see e.g. https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L579-L588
For these reasons, I don't think we'll encounter the above error.
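The second argument can be checked with a small worked example. This is a pure-Python sketch under stated assumptions: EOS is encoded as `0`, the sender emits at most `max_len` symbols, and one EOS is always appended afterwards, as described above. `append_eos` and `find_length` are hypothetical helpers, not EGG functions.

```python
EOS = 0  # assumption: EOS is encoded as symbol 0


def append_eos(sender_symbols: list[int], max_len: int) -> list[int]:
    """Keep at most max_len sender symbols, then always append one EOS."""
    if max_len < 1:
        raise ValueError("max_len must be >= 1")
    return sender_symbols[:max_len] + [EOS]


def find_length(message: list[int]) -> int:
    """Length up to and including the first EOS (full length if none)."""
    for i, symbol in enumerate(message):
        if symbol == EOS:
            return i + 1
    return len(message)


# max_len = 1 and the sender emits only EOS: the message has 2 positions.
message = append_eos([EOS], max_len=1)
print(message, find_length(message))  # [0, 0] 1

# Because an EOS is always appended, the length stays between 1 and max_len + 1.
for symbols in ([EOS], [7], [7, 7, 7]):
    msg = append_eos(symbols, max_len=2)
    print(msg, find_length(msg))
# [0, 0] 1
# [7, 0] 2
# [7, 7, 0] 3
```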
If we consider solution 2 that you proposed in #138, `find_lengths` would need to return 0 when a message consists only of the EOS token. Or maybe I misunderstood the proposal?
Yes, but I think this was before we appended an EOS to every message, so that proposal is outdated. Do you have anything in mind? :)
The docstring of `TransformerSenderReinforce` mentions that the `max_len` parameter includes the EOS token: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L687
However, the transformer can create a message of `max_len` symbols (without EOS), and the EOS token is appended afterwards: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L835
So, to my understanding, the `max_len` parameter does not actually include the EOS token?
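A minimal pure-Python sketch of the off-by-one described in this issue, assuming EOS is encoded as `0` and a sender that never emits EOS on its own. `message_as_described` follows the docstring's reading (`max_len` counts the EOS) and `message_as_implemented` follows the behaviour the issue describes (`max_len` symbols, then EOS appended); both are hypothetical helpers, not EGG code.

```python
EOS = 0  # assumption: EOS is encoded as symbol 0


def message_as_described(generated: list[int], max_len: int) -> list[int]:
    """Docstring reading: max_len counts the EOS, so at most
    max_len - 1 generated symbols fit before the final EOS."""
    return generated[: max_len - 1] + [EOS]


def message_as_implemented(generated: list[int], max_len: int) -> list[int]:
    """Behaviour described in the issue: max_len generated symbols,
    then EOS appended afterwards."""
    return generated[:max_len] + [EOS]


max_len = 3
generated = [4, 8, 15]  # a sender that never emits EOS on its own

print(len(message_as_described(generated, max_len)))    # 3
print(len(message_as_implemented(generated, max_len)))  # 4
```

If the issue's reading is right, the produced messages have up to `max_len + 1` positions, which is what the docstring would need to reflect.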