kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License
935 stars 151 forks

The decoding module does not support multiple GPUs. #31

Closed lemonhu closed 2 years ago

lemonhu commented 5 years ago

The decode function does not support multiple GPUs (see torchcrf/__init__.py#L117). The following error appears when running on multiple GPUs:

  File "/root/anaconda2/envs/pytorch1.0/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration
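To illustrate why this line fails, here is a minimal pure-Python sketch. It assumes that decode returns a list of tag lists per replica, and it reproduces only the fallback line quoted in the traceback; the `gather_map` below is a simplified stand-in, not the actual PyTorch source, and `replica_outputs` is made-up data.

```python
def gather_map(outputs):
    """Simplified stand-in: only the fallback branch quoted in the traceback."""
    out = outputs[0]
    # The real function handles tensors (and a few other types) before this
    # line; plain Python ints are not handled and fall through to here.
    return type(out)(map(gather_map, zip(*outputs)))


# Hypothetical decode() results from two replicas: each is a list of tag lists.
replica_outputs = [
    [[0, 1, 2], [1, 1, 0]],  # replica on GPU 0
    [[2, 0, 1], [0, 0, 1]],  # replica on GPU 1
]

try:
    gather_map(replica_outputs)
    failed = False
except TypeError as exc:
    # The recursion descends list -> list -> int and then calls zip() on ints.
    failed = True
    print("TypeError:", exc)
```

The recursion unpacks the nested lists until `outputs` is a tuple of plain integers, at which point `zip(*outputs)` is called on non-iterable values. The exact wording of the TypeError depends on the Python version.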

Looking forward to your reply.

kmkurn commented 5 years ago

Hi, thanks for reporting. I personally have no access to multiple GPUs, so I don't think I can fix and test this. Any PR on this would be welcome.

Dhanachandra commented 4 years ago

There are two solutions:

1. Return a torch.Tensor instead of a list from the _viterbi_decode method: return torch.Tensor(best_tags_list). Note: all sentences in a batch must have the same length; otherwise there will be an error in scatter_gather.py.

2. Define a method in your network that calls the crf.decode method:

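def predict(self, input_ids, attn_masks):
    # Emission scores from the underlying network
    emission = model(input_ids, attn_masks)
    # Viterbi decoding; returns a list of tag lists, one per sequence
    prediction = self.crf.decode(emission, mask=attn_masks)
    return prediction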

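Note: this method has to be called as model.module.predict(input_ids, attn_masks), where model is the parallel wrapper (e.g. nn.DataParallel) and model.module is the underlying network, so the call bypasses the gather step.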
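For the first solution, the same-length restriction could be worked around by padding the decoded tags before building the tensor. The following is a minimal sketch in plain Python under that assumption; `pad_tag_sequences`, `length`, and `pad_value=-1` are hypothetical names and values chosen for illustration and are not part of pytorch-crf.

```python
from typing import List


def pad_tag_sequences(tag_lists: List[List[int]], length: int,
                      pad_value: int = -1) -> List[List[int]]:
    """Right-pad every tag list to `length` so the batch is rectangular."""
    padded = []
    for tags in tag_lists:
        if len(tags) > length:
            raise ValueError("tag sequence is longer than the target length")
        padded.append(tags + [pad_value] * (length - len(tags)))
    return padded


# Hypothetical decode() output for three sentences of different lengths.
best_tags_list = [[0, 1, 2], [1, 0], [2]]
padded = pad_tag_sequences(best_tags_list, length=4)
print(padded)  # [[0, 1, 2, -1], [1, 0, -1, -1], [2, -1, -1, -1]]
```

Padding to the full input sequence length, rather than to the longest sentence in each replica's share of the batch, should keep the shapes identical across replicas. The rectangular result can then be converted with torch.tensor(padded), which likely needs to be placed on the same device as the emissions for the gather step to accept it. Padding positions can be stripped afterwards using the mask.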