Closed MBAnslow closed 3 years ago
Sorry, I don't fully understand yet why this occurs. It's been some time since I looked at this code, but from what I can recall, input_sequence should never be empty.
That is, because `input_sequence` is first updated in line 156: `input_sequence = self._sample(logits)`, and that should never return something empty. Then it is updated in line 171: `input_sequence = input_sequence[running_seqs]`; here we also ensured that `running_seqs` has at least one element.
So I think your workaround works, but I'm not sure why the problem occurs. Or am I missing something? I will be running the code these days to see if I can figure it out.
I think I've figured this out: this error occurs when input_sequence is a tensor that contains within it one index; i.e. input_sequence is a single scalar.
This occurs because when `self._sample(logits)` is called, the output of `torch.topk(logits, 1, dim=-1)` is squeezed. When `logits` contains n samples, n > 1, `torch.topk(logits, 1, dim=-1)` will yield a tensor where one dimension is n; thus, when the tensor is squeezed, the dimension of length n is not eliminated. But when the number of samples is 1, all of the tensor's dimensions have length 1, so they are all squeezed out, yielding a scalar / 0-d tensor instead.
There seem to be a few solutions to this. If we expect the output of `self._sample(logits)` to always be a 1-d tensor, we could use `sample.flatten()`, though this was introduced in torch 0.4.1. Alternatively, changing `sample.squeeze()` to `sample.view(-1)` will work, and should be compatible with whichever version of torch this is written against. If we don't always expect a 1-d tensor, we could use `unsqueeze` to add a dimension when we find that we've created a scalar tensor. Or, we could squeeze only the specific dimensions that are relevant.
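As a sketch, the options compare like this on the problematic single-sequence case (same assumed shapes as above):

```python
import torch

sample = torch.topk(torch.randn(1, 1, 10), 1, dim=-1)[1]  # shape (1, 1, 1)

print(sample.squeeze().shape)               # torch.Size([])  -- the buggy 0-d case
print(sample.view(-1).shape)                # torch.Size([1]) -- always 1-d, works on old torch
print(sample.flatten().shape)               # torch.Size([1]) -- same, needs torch >= 0.4.1
print(sample.squeeze().unsqueeze(0).shape)  # torch.Size([1]) -- repair after the fact
```

All three fixes agree for the n > 1 case as well, since they only differ in how they treat length-1 dimensions.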
I have this problem too, and I worked around it by explicitly specifying the dimensions to be squeezed out, as you suggested. This works with PyTorch 1.4:

```python
# sample = sample.squeeze()
sample = sample.squeeze(-1).squeeze(-1)  # there is no .squeeze(dims=[1, 2]) call in PyTorch 1.4
```
should now be solved by #26
There is a bug which arises when `len(running_seqs) == 1` and `len(input_sequence.size()) == 0`. This is fairly rare, occurring something like 1/100 times during inference. It happens when there is just one sequence yet to be completed and it isn't at the max length for that sentence, so the loop still tries to generate. The issue is with trying to index `input_sequence = input_sequence[running_seqs]`, which gives the error "IndexError: too many indices for tensor of dimension 0." My fairly naive solution (still getting into working with tensors) is the following, which simply doesn't try to index the input sequence, as it is just a scalar. Then it needs to be handled differently at the top of the while loop too:
It needs the added dimensionality to match the shape required to run it through the network.
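A hedged sketch of that repair: if `input_sequence` has come out as a 0-d tensor, `unsqueeze` can restore the missing dimensions before the tensor goes back through the network. The exact shape the decoder expects is an assumption here:

```python
import torch

input_sequence = torch.tensor(5)                  # 0-d tensor, as in the bug
if input_sequence.dim() == 0:
    input_sequence = input_sequence.unsqueeze(0)  # shape (1,) -- one running sequence
print(input_sequence.shape)                       # torch.Size([1])

# If the decoder expects (batch, seq_len), add a sequence dimension as well.
print(input_sequence.unsqueeze(1).shape)          # torch.Size([1, 1])
```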