kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

[Question] The prediction does not change during training #115

Closed Viictte closed 1 year ago

Viictte commented 1 year ago

Hi! I tried to add a CRF module to Blenderbot to predict the responding strategy for each sentence. More specifically, I extracted the relevant information from the encoder of Blenderbot and projected it to a (batch_size, sentence_len, num_tags) tensor. However, I found that neither the CRF layer nor the projection linear layer is updated during training.

Could you please help me check where I made a mistake?

```python
    # strategy_logits: 10 x 10 x 8 (batch_size x turns x num_strats)
    emission_matrix = nn.functional.softmax(strategy_logits, dim=-1)

    tags = turn_strategies.clone() - 1

    # Convert -100 values in turn_strategies to a valid index, e.g., 0
    tags.masked_fill_(is_pad, 0)
    # batch_size * num_turns

    masked_lm_loss = None
    if labels is not None:
        # for training
        crf_loss = -self.crf(emission_matrix, tags, ~is_pad, reduction="mean")

        print(self.strategy_projector.weight)

        loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none')
        loss = loss.view(labels.size(0), labels.size(1))
        # Cross entropy loss
        label_size = torch.sum(labels.ne(-100), dim=1).type_as(loss)
        # Averaging
        masked_lm_loss = torch.sum(loss) / torch.sum(label_size)
        # Calculate perplexity
        ppl_value = torch.exp(torch.mean(torch.sum(loss, dim=1).float() / label_size.float()))

```

The ppl_value and crf_loss are returned to the train function, where loss = crf_loss + ppl_loss is computed and then backpropagated.

kmkurn commented 1 year ago

At a glance, the code looks OK to me. If you could post a minimal reproducible example, that would be helpful.

Viictte commented 1 year ago

Thank you for replying. I have solved this problem by using two optimizers with different learning rates: the learning rate typically used for fine-tuning BERT-style models is small (e.g. 3e-5), which might not be effective for training the CRF layer.
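In case it helps others hitting the same symptom, the fix can be sketched like this (module names such as `encoder` and `strategy_projector` are placeholders for the real model; a single optimizer with per-parameter groups is equivalent to two separate optimizers):

```python
import torch
import torch.nn as nn

# Placeholder modules standing in for the pretrained encoder and the CRF head
encoder = nn.Linear(16, 8)            # stands in for the Blenderbot encoder
strategy_projector = nn.Linear(8, 8)  # projection to num_tags
crf_transitions = nn.Parameter(torch.zeros(8, 8))  # stands in for the CRF parameters

# Small LR for the pretrained part, larger LR for the freshly initialized head
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 3e-5},
    {"params": strategy_projector.parameters(), "lr": 1e-3},
    {"params": [crf_transitions], "lr": 1e-3},
])

print([group["lr"] for group in optimizer.param_groups])
```

With 3e-5 applied uniformly, the randomly initialized projection and CRF parameters move so slowly that they can look frozen over a few epochs, which matches the symptom in the original post.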