kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Two questions about usage: binary task and weights #84

Closed: contribcode closed this issue 3 years ago

contribcode commented 3 years ago

Hello,

I have two questions about using the library. Since they are not bug reports, I am opening a single issue for both. If there is a more appropriate place to ask questions about the library, let me know and I will move them there.

Question 1

First, I want to use the CRF layer for a binary task. Should the number of tags of the CRF layer be one or two? I ask because, for binary tasks, many libraries' functions accept parameters only for the positive class.
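The question comes down to how a linear-chain CRF scores a tag sequence. Below is a from-scratch, brute-force sketch for illustration only: it enumerates every possible tag sequence to normalise, whereas a real implementation such as pytorch-crf uses the forward algorithm. The function names and all numbers are hypothetical. Every tag, including the negative class, gets its own emission score at each position, its own start/end score, and its own row and column in the transition matrix, which is why both classes are modelled as tags.

```python
import itertools
import math


def sequence_score(emissions, tags, start, end, trans):
    """Unnormalised score of one tag sequence in a linear-chain CRF (hypothetical helper)."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        # transition from the previous tag plus the emission score of the current tag
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    return score + end[tags[-1]]


def log_likelihood(emissions, tags, start, end, trans):
    """log p(tags | emissions), normalising by brute force over all tag sequences."""
    num_tags = len(start)
    all_scores = [
        sequence_score(emissions, list(seq), start, end, trans)
        for seq in itertools.product(range(num_tags), repeat=len(emissions))
    ]
    log_z = math.log(sum(math.exp(s) for s in all_scores))
    return sequence_score(emissions, tags, start, end, trans) - log_z


# Made-up scores for a 3-token sentence with 2 tags (0 = negative, 1 = positive)
emissions = [[0.5, -0.2], [0.1, 0.9], [-0.3, 0.4]]
start = [0.0, -1.0]
end = [0.2, 0.0]
trans = [[0.3, -0.5], [-0.4, 0.6]]

ll = log_likelihood(emissions, [0, 1, 1], start, end, trans)
```

With only one tag there would be a single possible sequence, so its probability would always be 1 and the layer could not learn anything.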

Question 2

My second question is about weights. I want to use a CRF layer on top of a BERT model. My implementation is the following:

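import torch.nn as nn
import torch.optim as optim
from torchcrf import CRF
from transformers import BertModel

crf_n_tags = 2  # number of tags for the binary task

def model_parameters(model_params, crf_params):
    # Yield the encoder's parameters first, then the CRF layer's
    yield from model_params
    yield from crf_params

class Bert_Clf(nn.Module):

    def __init__(self, bert_model_arg):
        super().__init__()
        self.bert_model = bert_model_arg
        self.dropout = nn.Dropout(0.1)
        # Map each token's hidden state (768 for bert-base) to one emission score per tag
        self.fc = nn.Linear(bert_model_arg.config.hidden_size, crf_n_tags)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert_model(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # last hidden state: (batch, seq_len, hidden_size)

        sequence_output = self.dropout(sequence_output)
        logits = self.fc(sequence_output)  # emissions: (batch, seq_len, crf_n_tags)
        return logits

bert = BertModel.from_pretrained("bert-base-uncased")
model = Bert_Clf(bert_model_arg=bert)
crf_layer = CRF(crf_n_tags, batch_first=True)
optimizer = optim.AdamW(model_parameters(model.parameters(), crf_layer.parameters()))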

I define the CRF layer outside the model definition because it returns the log likelihood. I then pass the parameters of both to the optimizer. Is this the right way to update the weights, i.e. the model along with the CRF layer?

kmkurn commented 3 years ago

Hi, thanks for using the library! About your questions:

  1. In that case, the number of tags should be 2.
  2. That seems correct to me. Python's itertools.chain could replace your model_parameters function; apologies if you already know this.
contribcode commented 3 years ago

Thank you @kmkurn for your fast reply; that answers both of my questions. Thank you for the library, too.