kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

loss unstable #55

Closed MartinGOGO closed 4 years ago

MartinGOGO commented 4 years ago

I'm interested in using this library for Named Entity Recognition, but I've run into a problem. I'm using PyTorch to build a model with one embedding layer, one LSTM layer, and a CRF layer. The model structure is shown below.

import torch.nn as nn
from torchcrf import CRF


class my_model(nn.Module):

    def __init__(self, vocab_size, embedding_size, hidden_size, num_classes, pad_idx):
        super(my_model, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
        self.num_classes = num_classes
        self.lstm = nn.LSTM(embedding_size, hidden_size)
        self.linear = nn.Linear(hidden_size, num_classes)
        self.crf = CRF(self.num_classes)

    def forward(self, texts, labels, masks):
        embedded = self.embed(texts)
        output, (hidden, cell) = self.lstm(embedded)
        out_vocab = self.linear(output)
        # out_vocab: [seq_len, batch_size, num_classes]
        # labels: [seq_len, batch_size]
        # masks: [seq_len, batch_size]
        # The CRF returns the log-likelihood; negate it to get a loss to minimize.
        loss = -(self.crf(out_vocab, labels, mask=masks))
        return loss

The problem is that the loss is very high and very unstable during training. It often jumps from 200+ to 20+, and then to 500+. I wonder if it's because I'm using this library incorrectly?

kmkurn commented 4 years ago

Hi, the code looks OK to me. By default, the loss is summed over tokens, so longer sentences have larger loss. You can pass reduction='token_mean' to have the loss averaged over tokens instead. It should be more stable.
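To illustrate why summing makes the loss track sentence length, here is a minimal pure-Python sketch of the arithmetic. The helper `reduce_loss` and all the numbers are hypothetical and only mimic, as far as I understand them, what the `reduction` modes do; it is not the library's code. In the model above, the actual change would be passing `reduction='token_mean'` to the `self.crf(...)` call.

```python
# Hypothetical helper: mimics the arithmetic behind the reduction modes.
# This is an illustration only, not pytorch-crf's implementation.
def reduce_loss(nll_per_sentence, tokens_per_sentence, reduction="sum"):
    total = sum(nll_per_sentence)
    if reduction == "sum":
        # Total over the whole batch: grows with the number of tokens.
        return total
    if reduction == "mean":
        # Average per sentence: still grows with sentence length.
        return total / len(nll_per_sentence)
    if reduction == "token_mean":
        # Average per (unmasked) token: comparable across batches.
        return total / sum(tokens_per_sentence)
    raise ValueError(f"unknown reduction: {reduction}")


# Two made-up batches with the same per-token loss (1.0) but different lengths.
short_batch = ([4.0, 6.0], [5, 5])        # 2 sentences, 10 tokens in total
long_batch = ([40.0, 60.0], [50, 50])     # 2 sentences, 100 tokens in total

for mode in ("sum", "token_mean"):
    print(mode, reduce_loss(*short_batch, mode), reduce_loss(*long_batch, mode))
```

With `sum`, the two batches differ by a factor of ten even though the model is equally good on both, which is the kind of jump described in the question. With `token_mean`, both come out the same.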

MartinGOGO commented 4 years ago

> Hi, the code looks OK to me. By default, the loss is summed over tokens, so longer sentences have larger loss. You can pass reduction='token_mean' to have the loss averaged over tokens instead. It should be more stable.

Thanks a lot. It looks better when I pass reduction='token_mean'.