kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Error with _validate on gpu #33

Closed: light8lee closed this issue 5 years ago

light8lee commented 5 years ago

I'm using this model with Python 3.6.5 and PyTorch 1.0.1 in Docker. Here is the traceback:

  ...
  File "/share/E4G0/models/up_crf.py", line 25, in forward
    scores = self.crf(emissions, target_tags, input_masks.long())
  File "/home/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/anaconda3/lib/python3.6/site-packages/torchcrf/__init__.py", line 90, in forward
    self._validate(emissions, tags=tags, mask=mask)
  File "/home/anaconda3/lib/python3.6/site-packages/torchcrf/__init__.py", line 165, in _validate
    no_empty_seq_bf = self.batch_first and mask[:, 0].all()
RuntimeError: _th_all is not implemented for type torch.cuda.LongTensor

It seems that mask[:, 0].all() doesn't work on a torch.cuda.LongTensor.
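For context, the failing line in _validate (line 165 in the traceback) checks that, with batch_first=True, column 0 of the mask (mask[:, 0], one entry per sequence) is nonzero for every sequence in the batch, i.e. no sequence is empty. The sketch below mirrors that rule on plain nested lists so it runs without PyTorch. The function name first_timestep_all_on is hypothetical and not part of pytorch-crf, and the time-first branch is an assumption about how the non-batch_first layout is checked.

```python
from typing import Sequence


def first_timestep_all_on(mask: Sequence[Sequence[int]], batch_first: bool) -> bool:
    """Hypothetical pure-Python analogue of the mask check in _validate.

    With batch_first=True the mask is laid out as (batch, seq_len), so the
    first timestep is column 0 (mask[:, 0] in tensor terms). Otherwise it is
    assumed to be (seq_len, batch), so the first timestep is row 0 (mask[0]).
    """
    if batch_first:
        first_step = [row[0] for row in mask]
    else:
        first_step = list(mask[0])
    # Mirrors Tensor.all(): True only if every entry is nonzero.
    return all(value != 0 for value in first_step)


# Two sequences of lengths 3 and 2, padded to length 3, batch-first layout.
mask = [[1, 1, 1],
        [1, 1, 0]]
print(first_timestep_all_on(mask, batch_first=True))      # True

# A batch containing an empty sequence fails the check.
bad_mask = [[1, 1, 1],
            [0, 0, 0]]
print(first_timestep_all_on(bad_mask, batch_first=True))  # False
```

In PyTorch 1.0.1 the real check never gets this far with a LongTensor mask: calling .all() on it raises the RuntimeError shown in the traceback before any values are compared.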

kmkurn commented 5 years ago

Hi,

The mask is expected to be a ByteTensor, not a LongTensor. Try changing your code to scores = self.crf(emissions, target_tags, input_masks.byte()).

light8lee commented 5 years ago

Thank you. I found the relevant note in the PyTorch docs: all() is one of the methods unique to torch.ByteTensor.