facebookresearch / EGG

EGG: Emergence of lanGuage in Games

undesirable behaviour of find_lengths function #138

Open mbaroni opened 3 years ago

mbaroni commented 3 years ago

```python
def find_lengths(messages: torch.Tensor) -> torch.Tensor:
    """
    :param messages: A tensor of term ids, encoded as Long values, of size
        (batch size, max sequence length).
    :returns A tensor with lengths of the sequences, including the
        end-of-sequence symbol <eos> (in EGG, it is 0). If no <eos> is found,
        the full length is returned (i.e. messages.size(1)).
    """
```

This leads to counterintuitive behaviour in which, if max_len is 3, [1, 2, 3] and [1, 2, 0] have the same length.
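A minimal reproduction of this (assuming `find_lengths` is importable from `egg.core.util`):

```python
import torch

from egg.core.util import find_lengths  # assumed import path

messages = torch.tensor([
    [1, 2, 3],  # no <eos>: the full length is returned
    [1, 2, 0],  # <eos> in the last position: the length includes the <eos>
])
print(find_lengths(messages))  # tensor([3, 3]) -- both reported as length 3
```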

robertodessi commented 3 years ago

Quickly thinking about it, I see two possible options:

  1. not allowing messages without EOS (i.e. raising an error/throwing an exception, or returning a special value like 0 or -1)
  2. considering [1, 2, 3] as length 3 and [1, 2, 0] as length 2 (sketched below)

I have a preference for 2. but am open to discussing it
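A minimal sketch of what option 2 could look like (the helper name `find_lengths_option2` is hypothetical; only the EOS-is-0 convention comes from EGG):

```python
import torch

def find_lengths_option2(messages: torch.Tensor) -> torch.Tensor:
    # Count the symbols strictly before the first <eos> (0) in each row:
    # [1, 2, 0] -> 2, while rows with no <eos>, like [1, 2, 3], keep the
    # full length messages.size(1).
    before_eos = (messages == 0).long().cumsum(dim=1) == 0
    return before_eos.sum(dim=1)
```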

tomkouwenhoven commented 1 year ago

Hi,

Is there any update on this issue? I am also working on variable-length communication, and appending an EOS token to each message (see the three lines of code below) causes find_lengths to return opts.max_len + 1 whenever the sender itself produces no EOS token. This is counterintuitive when one specifies max_len to be a specific value and then observes returned lengths longer than that value.

```python
sequence = torch.stack(sequence).permute(1, 0)
zeros = torch.zeros((sequence.size(0), 1)).to(sequence.device)
sequence = torch.cat([sequence, zeros.long()], dim=1)
```
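To illustrate the off-by-one (again assuming the `egg.core.util` import path, with opts.max_len = 3):

```python
import torch

from egg.core.util import find_lengths  # assumed import path

# A sender message that fills all opts.max_len = 3 slots with no <eos>:
sender_output = torch.tensor([[1, 2, 3]])
zeros = torch.zeros((sender_output.size(0), 1), dtype=torch.long)
padded = torch.cat([sender_output, zeros], dim=1)  # [[1, 2, 3, 0]]
print(find_lengths(padded))  # tensor([4]), i.e. opts.max_len + 1
```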

In terms of the options you mentioned on November 5th 2020, I would also opt for option 2.

Is the current best workaround still to increase the max_len parameter by one, as mentioned in #188?
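For reference, I understand that workaround to look something like this (assuming EGG's standard core.init entry point and its --max_len command-line option):

```python
import egg.core as core

# Request one extra slot so that an appended <eos> still fits in the
# message; here the desired effective maximum length is 3.
opts = core.init(params=["--max_len=4"])
```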

Thanks in advance, Tom Kouwenhoven

robertodessi commented 1 year ago

Hi,

No concrete plans to work on this in the near future. Do you want to give it a go? :)