Closed · lemonhu closed this issue 2 years ago
Hi, thanks for reporting. I personally have no access to multiple GPUs, so I don't think I can fix and test this. Any PR on this would be welcome.
There are two solutions:
1. Return a torch.Tensor instead of a list from your _viterbi_decode method:

   ```python
   # torch.tensor infers an integer dtype from the tag indices
   return torch.tensor(best_tags_list)
   ```

   Note: all sentences in a batch must have the same length here; otherwise scatter_gather.py raises an error when gathering the per-GPU results.
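The equal-length requirement comes from the tensor conversion: a nested list must be rectangular before it can become one tensor. A minimal, torch-free sketch of padding decoded tag sequences to the batch maximum (`pad_tags` is a hypothetical helper, not part of torchcrf):

```python
def pad_tags(best_tags_list, pad_value=0):
    """Pad variable-length tag sequences to the batch's maximum length.

    Padding makes the nested list rectangular, so it can be converted to
    a single tensor that scatter_gather.py is able to concatenate.
    """
    max_len = max(len(tags) for tags in best_tags_list)
    return [tags + [pad_value] * (max_len - len(tags)) for tags in best_tags_list]

# Example: two decoded sentences of different lengths.
batch = [[3, 1, 4], [2, 7]]
print(pad_tags(batch))  # [[3, 1, 4], [2, 7, 0]]
```

If you pad, remember to strip the `pad_value` entries again (e.g. using the attention mask) before evaluating the predictions.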
2. Define a method on your network that calls crf.decode:

   ```python
   def predict(self, input_ids, attn_masks):
       # run the forward pass on this module to get emission scores
       emissions = self(input_ids, attn_masks)
       # crf.decode returns one tag sequence per sentence
       return self.crf.decode(emissions, mask=attn_masks)
   ```

   Note: since nn.DataParallel does not expose custom methods of the wrapped network, this has to be called as model.module.predict(input_ids, attn_masks).
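The reason for reaching through `.module` can be shown without torch: a DataParallel-style wrapper only delegates the forward call, so custom methods remain on the inner module. The class names below are illustrative, not the real DataParallel implementation:

```python
class Net:
    def forward(self, x):
        return x * 2

    def predict(self, x):
        # custom method defined on the network itself
        return self.forward(x) + 1

class Wrapper:
    """Mimics how nn.DataParallel exposes only the forward pass,
    while keeping the original network reachable via .module."""
    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        # the wrapper only delegates forward(), nothing else
        return self.module.forward(x)

model = Wrapper(Net())
print(model(3))                 # forward works through the wrapper: 6
print(model.module.predict(3))  # custom methods need .module: 7
```

Calling `model.predict(...)` directly on the wrapper would raise an AttributeError, which is exactly why the `.module` indirection is required.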
The decode function does not support multiple GPUs (see torchcrf/__init__.py#L117); the following bug appears when running on multiple GPUs.
Looking forward to your reply.