kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Unexpected IndexError is result of dtype change #44

Closed: bpshaver closed this issue 4 years ago.

bpshaver commented 4 years ago

I noticed this in my own code and spent some time debugging it. Here's a modified version of the example from the docs:

```python
import torch
from torchcrf import CRF

num_tags = 5
model = CRF(num_tags)

seq_length = 3
batch_size = 2
emissions = torch.randn(seq_length, batch_size, num_tags)
tags = torch.tensor([
  [0, 1], [2, 4], [3, 1]
], dtype=torch.uint8)  # the docs example uses dtype=torch.long here

log_likelihood = model(emissions, tags)
```

Changing the dtype of tags from torch.long to torch.uint8 raises the following error:

```
IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [5] at index 0
```

As far as I can tell, this is unexpected behavior.

bpshaver commented 4 years ago

(At least, it isn't immediately clear to me what causes the error. More importantly, it looks like an indexing problem, not a data-type problem.)

kmkurn commented 4 years ago

This is PyTorch's behaviour. torch.uint8 means ByteTensor, whose indexing semantics differ from those of LongTensor: a ByteTensor index acts as a mask and selects the elements where it is non-zero (see here). That requires the mask to have the same shape as the indexed tensor, hence the error. CRF.forward requires tags to be a LongTensor (docs), so you should convert them beforehand.
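A minimal sketch of the indexing difference, using plain PyTorch rather than torchcrf. The tensor `scores` is an illustrative stand-in (an assumption, not pytorch-crf code) for a per-tag score vector of size num_tags that gets indexed with the tags of one time step, one tag per batch element. A long index picks elements by position; a uint8 index of the same values is treated as a mask and fails on the shape mismatch, and `.long()` fixes it.

```python
import warnings

import torch

num_tags = 5
# Stand-in for a per-tag score vector: tensor([0., 1., 2., 3., 4.])
scores = torch.arange(num_tags, dtype=torch.float)

# Tags for one time step across a batch of 2 sequences.
step_tags = [0, 1]

# LongTensor index: integer (positional) indexing, picks scores[0] and scores[1].
long_idx = torch.tensor(step_tags, dtype=torch.long)
picked = scores[long_idx]

with warnings.catch_warnings():
    # Recent PyTorch versions may warn that uint8 indices are deprecated.
    warnings.simplefilter("ignore")

    # A uint8 index with the SAME shape as `scores` works, but as a mask:
    # it keeps the elements where the index is non-zero.
    mask5 = torch.tensor([1, 0, 1, 0, 0], dtype=torch.uint8)
    masked = scores[mask5]

    # A uint8 index of shape [2] against a tensor of shape [5] is also read
    # as a mask, so the shapes must match and PyTorch raises IndexError.
    byte_idx = torch.tensor(step_tags, dtype=torch.uint8)
    mask_error = None
    try:
        scores[byte_idx]
    except IndexError as exc:
        mask_error = str(exc)

# Fix: convert to torch.long before indexing (or before calling the CRF).
fixed = scores[byte_idx.long()]
print(picked, masked, fixed, mask_error, sep="\n")
```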

bpshaver commented 4 years ago

Thanks! I can see this is a PyTorch issue, not a pytorch-crf issue. I have some reading to do, but at minimum it seems like the PyTorch people could improve their error messages in this regard.
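Until the upstream message changes, one option in user code is a small pre-flight check that fails with a dtype-specific message before the tags ever reach the CRF. This is a sketch under stated assumptions: `check_crf_tags` is a hypothetical helper, not part of pytorch-crf or PyTorch, and the value-range check is an extra safeguard added for illustration.

```python
import torch


def check_crf_tags(tags: torch.Tensor, num_tags: int) -> torch.Tensor:
    """Hypothetical guard to call before passing tags to a CRF.

    Raises a TypeError that names the dtype problem, instead of letting
    PyTorch's mask-indexing IndexError surface from inside the model.
    """
    if tags.dtype != torch.long:
        raise TypeError(
            f"tags must have dtype torch.long, got {tags.dtype}; "
            "convert with tags.long()"
        )
    # Extra safeguard (an assumption, not required by the issue):
    # every tag must be a valid index into the num_tags tag set.
    if tags.numel() > 0 and (tags.min() < 0 or tags.max() >= num_tags):
        raise ValueError(f"tag values must lie in [0, {num_tags})")
    return tags


# Usage with the tags from the report (dtype=torch.uint8).
bad_tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.uint8)
error_message = None
try:
    check_crf_tags(bad_tags, num_tags=5)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)

# After conversion the check passes and the tensor is returned unchanged.
good_tags = check_crf_tags(bad_tags.long(), num_tags=5)
```

Raising `TypeError` rather than `IndexError` keeps the message pointed at the actual cause (the dtype), which is the confusion described in this thread.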