bpshaver closed this 4 years ago.
(At least, it isn't immediately clear to me what is causing the error, and, more importantly, it looks like an indexing problem, not a data-type problem.)
This is PyTorch's behaviour. `torch.uint8` means `ByteTensor`, so the indexing semantics differ from those of `LongTensor`: a byte index selects the tensor elements at the positions where the index is non-zero (see here). This requires the index tensor to have the same shape as the indexed tensor, hence the error. `CRF.forward` requires `tags` to be a `LongTensor` (docs), so you should convert beforehand.
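A minimal sketch of the two indexing semantics described above, using a made-up 5-element `scores` tensor. It uses `torch.bool` for the mask; at the time of this thread `torch.uint8` indices were interpreted as masks in the same way, while newer PyTorch versions prefer `torch.bool` and warn on `torch.uint8`. The exact error type and message may vary between PyTorch versions, so the sketch catches both `IndexError` and `RuntimeError`.

```python
import torch

# A 1-D tensor to index into (think: scores for 5 tags).
scores = torch.tensor([10.0, 11.0, 12.0, 13.0, 14.0])

# LongTensor index: each value is a position to gather, so any length works.
positions = torch.tensor([0, 1, 3], dtype=torch.long)
picked = scores[positions]  # tensor([10., 11., 13.])

# Mask index: each entry says keep/drop the element at the same position,
# so the mask must have the same shape as the tensor it indexes.
mask = torch.tensor([True, False, True, False, False])
kept = scores[mask]  # tensor([10., 12.])

# Reinterpreting the same positions as a mask gives a 3-element mask for a
# 5-element tensor, which fails with a shape error rather than a dtype error.
try:
    scores[positions.bool()]
    shape_error = None
except (IndexError, RuntimeError) as err:
    shape_error = err

print(picked, kept, shape_error)
```

This is why the error in this issue reads like an indexing problem: the dtype silently switches PyTorch from "gather these positions" to "keep where non-zero", and only the resulting shape mismatch gets reported.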
Thanks! I can see this is a PyTorch issue, not a pytorch-crf issue. I have some reading to do, but at minimum it seems like the PyTorch people could improve their error messages in this regard.
I noticed this in my own code and spent some time debugging it. Here's a version of the code in the docs as an example:
The change in the `dtype` parameter of `tags` from `torch.long` to `torch.uint8` introduces the following error:

This is unexpected behavior, as far as I can tell.
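Following the maintainer's advice, the direct fix is to convert before the call, e.g. something like `model(emissions, tags.long())`. The sketch below illustrates why that works without needing pytorch-crf installed. `gold_emission_score` is a hypothetical helper, loosely modeled on how a CRF sums the emission scores of the gold tag sequence by indexing `emissions` with `tags`; it is not pytorch-crf's actual implementation. Shapes follow the `(seq_length, batch_size, num_tags)` layout used in the docs example.

```python
import torch


def gold_emission_score(emissions: torch.Tensor, tags: torch.Tensor) -> torch.Tensor:
    """Sum emissions[t, b, tags[t, b]] over time for each batch element.

    Hypothetical helper for illustration only.
    emissions: (seq_length, batch_size, num_tags)
    tags:      (seq_length, batch_size), tag ids
    """
    if tags.dtype != torch.long:
        # Byte/bool index tensors would be read as masks, not as tag ids,
        # so convert to LongTensor before using them as indices.
        tags = tags.long()
    seq_length, batch_size = tags.shape
    batch_idx = torch.arange(batch_size)
    score = torch.zeros(batch_size, dtype=emissions.dtype)
    for t in range(seq_length):
        # Advanced (integer) indexing: one emission score per batch element.
        score = score + emissions[t, batch_idx, tags[t]]
    return score


# Usage: uint8 tags give the same result as long tags after conversion.
emissions = torch.randn(3, 2, 5)  # (seq_length, batch_size, num_tags)
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.uint8)
print(gold_emission_score(emissions, tags))
```

Converting once at the boundary keeps every downstream indexing operation on integer semantics, which is the same reason `CRF.forward` documents `tags` as a `LongTensor`.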