mbaroni opened 3 years ago
Quickly thinking about it, I see two possible options:
I have a preference for 2., but am open to discussing it.
Hi,
Is there any update on this issue? I am also working on variable-length communication, and adding an EOS token to each message (see the three code lines below) causes messages to have length opts.max_len + 1
when the sender itself produces no EOS token. This is counterintuitive: one sets max_len to a specific value, yet find_lengths returns a length longer than that value.
```python
# `sequence` is a list of per-step symbol tensors, each of shape (batch_size,)
sequence = torch.stack(sequence).permute(1, 0)  # -> (batch_size, max_len)
zeros = torch.zeros((sequence.size(0), 1)).to(sequence.device)
sequence = torch.cat([sequence, zeros.long()], dim=1)  # -> (batch_size, max_len + 1)
```
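To make the effect concrete, here is a minimal reproduction (the import path for find_lengths is the one from the EGG version I use and may differ):

```python
import torch
from egg.core.util import find_lengths  # import path may vary across EGG versions

max_len = 3
message = torch.tensor([[1, 2, 3]])  # sender used all max_len slots, no EOS produced
eos_column = torch.zeros((message.size(0), 1), dtype=torch.long)
message = torch.cat([message, eos_column], dim=1)  # the appended EOS column from above

print(find_lengths(message))  # tensor([4]), i.e. max_len + 1
```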
In terms of the options you mentioned on November 5th, 2020, I would also opt for option 2.
Is the current best solution still to increase the max_len
parameter by one, as mentioned in #188?
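For now I handle it on my side by reserving one slot for the appended EOS when choosing max_len; a sketch of that, assuming the standard core.init / --max_len options (my reading of the workaround, not an official recipe):

```python
import egg.core as core

# I want reported lengths to stay <= 4, so I pass max_len = 3: even a message
# in which the sender never produces EOS ends up with length 3 + 1 == 4 after
# the EOS column is appended.
opts = core.init(params=["--max_len=3", "--vocab_size=10"])
```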
Thanks in advance, Tom Kouwenhoven
Hi,
No concrete plans to work on this in the near future. Do you want to give it a go? :)
For reference, this is the docstring of find_lengths:

```python
def find_lengths(messages: torch.Tensor) -> torch.Tensor:
    """
    :param messages: A tensor of term ids, encoded as Long values,
        of size (batch size, max sequence length).
    :returns A tensor with lengths of the sequences, including the
        end-of-sequence symbol <eos> (in EGG, it is 0).
        If no <eos> is found, the full length is returned (i.e. messages.size(1)).
    """
```
This leads to counterintuitive behaviour: with max_len = 3, the messages [1, 2, 3] and [1, 2, 0] are both assigned length 3, even though only the second one actually contains an EOS symbol.
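The behaviour is easy to reproduce with a self-contained sketch of the documented semantics (my reimplementation for illustration, not EGG's actual code):

```python
import torch

def find_lengths_sketch(messages: torch.Tensor) -> torch.Tensor:
    # Length = index of the first 0 (EOS) + 1; if no 0 occurs, fall back to
    # the full sequence length, exactly as the docstring above describes.
    zero_mask = messages == 0
    first_eos = zero_mask.int().argmax(dim=1)  # 0 when no EOS is present
    full_len = torch.full_like(first_eos, messages.size(1))
    return torch.where(zero_mask.any(dim=1), first_eos + 1, full_len)

msgs = torch.tensor([[1, 2, 3],   # no EOS anywhere
                     [1, 2, 0]])  # EOS in the last slot
print(find_lengths_sketch(msgs))  # tensor([3, 3]), indistinguishable
```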