huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It would make the model more robust. #20608

Open pratikchhapolika opened 1 year ago

pratikchhapolika commented 1 year ago

Model description

Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It would make the model more robust. There should be a simple notebook tutorial that teaches us how to add our own custom layer on top of Hugging Face models.

For example, take dslim/bert-base-NER and then add a CRF layer (from torchcrf import CRF) on top of it.

I am planning to do this, but I don't know how to code this feature. Any leads or a notebook example would be helpful.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF
from transformers import (BertConfig, BertForTokenClassification, BertTokenizer,
                          Trainer, TrainingArguments)

model_checkpoint = "dslim/bert-base-NER"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
bert_model = BertForTokenClassification.from_pretrained(model_checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)

class BERT_CRF(nn.Module):

    def __init__(self, bert_model, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)

        # the last four hidden states (4 x 768 features) are concatenated before classification
        self.classifier = nn.Linear(4 * 768, num_labels)

        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        # this is the line that fails (see below)
        sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1)
        sequence_output = self.dropout(sequence_output)

        emission = self.classifier(sequence_output)  # [32, 256, 17]

        if labels is not None:
            labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
            loss = -self.crf(F.log_softmax(emission, dim=2), labels, mask=attention_mask.type(torch.uint8), reduction='mean')
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return [loss, prediction]
        else:
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return prediction

# instantiate the custom model so the Trainer below has something to train
model = BERT_CRF(bert_model, num_labels=len(id2label))

args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32
    fp16=True
    # bf16=True #Ampere GPU
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer)

I get an error on the line sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1).

Since outputs = self.bert(input_ids, attention_mask=attention_mask) only returns the logits for token classification, how can I get the hidden states so that I can concatenate the last four of them, i.e. index outputs[1][-1]?
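
For reference, something along these lines is what I think is needed, although I have not verified it (it assumes BertForTokenClassification forwards the output_hidden_states flag to the encoder and returns the hidden states on its output object):

# inside forward(), ask the wrapped model for its hidden states
outputs = self.bert(input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True)
# outputs.hidden_states is a tuple of [batch, seq_len, 768] tensors,
# one per layer plus the embeddings, so the last four can be concatenated
hidden_states = outputs.hidden_states
sequence_output = torch.cat(
    (hidden_states[-1], hidden_states[-2], hidden_states[-3], hidden_states[-4]), dim=-1)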

Open source status

Provide useful links for the implementation

No response

NielsRogge commented 1 year ago

Hi,

Please use the forum for this kind of question. We'd like to keep GitHub issues for bugs and feature requests.

Thanks!

pratikchhapolika commented 1 year ago

> Hi,
>
> Please use the forum for this kind of question. We'd like to keep GitHub issues for bugs and feature requests.
>
> Thanks!

This is kind of a feature request only. @NielsRogge

sgugger commented 1 year ago

Models are fully defined in each modeling file in an independent fashion, so you can easily copy/paste them and then customize them to your needs :-)
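
For example, a rough sketch of that kind of customization could look like the following (untested, just to illustrate the idea; the class name is made up, and it wraps the base BertModel directly so the hidden states are available without extra flags):

import torch.nn as nn
from torchcrf import CRF
from transformers import BertModel

class BertCrfForTokenClassification(nn.Module):
    """Hypothetical BERT + CRF head, loosely mirroring BertForTokenClassification."""

    def __init__(self, model_checkpoint, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_checkpoint)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask=None, labels=None, token_type_ids=None):
        sequence_output = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        ).last_hidden_state
        emissions = self.classifier(self.dropout(sequence_output))
        mask = attention_mask.bool() if attention_mask is not None else None
        if labels is not None:
            # the CRF expects valid tag indices everywhere, so any -100 padding
            # labels would need to be replaced (e.g. with 0) before this call
            loss = -self.crf(emissions, labels, mask=mask, reduction="mean")
            return {"loss": loss, "logits": emissions}
        return {"logits": emissions, "predictions": self.crf.decode(emissions, mask=mask)}

Returning a dict with a "loss" key should keep it compatible with the Trainer setup above, but again, this is only a sketch.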