kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

[Question] mask of the first timestep must all be on #110

Closed (antct closed this issue 1 year ago)

antct commented 1 year ago

Hi, I found that the code validates the mask to ensure that every input sequence contains at least one token. It does this by requiring the mask of the first timestep to be on for every sequence in the batch:

if mask is not None:
    if emissions.shape[:2] != mask.shape:
        raise ValueError(
            'the first two dimensions of emissions and mask must match, '
            f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
    no_empty_seq = not self.batch_first and mask[0].all()
    no_empty_seq_bf = self.batch_first and mask[:, 0].all()
    if not no_empty_seq and not no_empty_seq_bf:
        raise ValueError('mask of the first timestep must all be on')
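
To make the behaviour concrete, here is a minimal, dependency-free sketch of the same condition, using nested lists in place of tensors. The helper name `first_timestep_all_on` and the sample masks are made up for illustration and are not part of pytorch-crf.

```python
def first_timestep_all_on(mask, batch_first):
    """Restate the quoted check on nested lists (hypothetical helper).

    Mirrors mask[:, 0].all() when batch_first is True and
    mask[0].all() otherwise.
    """
    if batch_first:
        # mask is (batch, seq_len): look at column 0 of every row
        return all(row[0] for row in mask)
    # mask is (seq_len, batch): look at the whole first row
    return all(mask[0])


# Two sequences, four timesteps, batch_first layout.
mask_ok = [[1, 1, 1, 0],
           [1, 1, 0, 0]]
# Same batch with position 0 masked out (e.g. a [CLS] token).
mask_cls_off = [[0, 1, 1, 0],
                [0, 1, 0, 0]]

print(first_timestep_all_on(mask_ok, batch_first=True))       # True
print(first_timestep_all_on(mask_cls_off, batch_first=True))  # False
```

With the second mask, the quoted check would raise the ValueError even though each sequence still has unmasked tokens.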

As others mentioned in #46, #85, and #61, the first token of a BERT encoder's input is always [CLS], which is not expected to take part in the calculation, so its mask is off. I think the correct way to check the mask is:

no_empty_seq_bf = self.batch_first and torch.all(torch.any(mask, dim=1), dim=0)
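
A sketch of what this proposed condition accepts and rejects, again on nested lists rather than tensors. `every_sequence_nonempty` is a hypothetical name, and the non-batch-first branch is an assumption about how the same idea would extend, since the line above only covers `batch_first`.

```python
def every_sequence_nonempty(mask, batch_first):
    """True if each sequence has at least one unmasked position.

    List-based analogue of torch.all(torch.any(mask, dim=1), dim=0)
    for batch_first masks (hypothetical helper).
    """
    if not batch_first:
        # Assumption: transpose (seq_len, batch) into (batch, seq_len).
        mask = list(zip(*mask))
    return all(any(row) for row in mask)


mask_cls_off = [[0, 1, 1, 0],   # position 0 masked, tokens remain
                [0, 1, 0, 0]]
mask_empty = [[0, 1, 1, 0],
              [0, 0, 0, 0]]     # second sequence has no tokens at all

print(every_sequence_nonempty(mask_cls_off, batch_first=True))  # True
print(every_sequence_nonempty(mask_empty, batch_first=True))    # False
```

This only illustrates the validation condition. It does not show that the rest of the CRF computation supports a masked first timestep.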

kmkurn commented 1 year ago

Hi, please see the related issue #46.