ThilinaRajapakse / pytorch-transformers-classification

Based on the PyTorch-Transformers library by HuggingFace. To be used as a starting point for employing Transformer models in text classification tasks. Contains code to easily train BERT, XLNet, RoBERTa, and XLM models for text classification.
Apache License 2.0

Ability to pass custom pytorch loss function #9

Closed: pythonometrist closed this issue 5 years ago

pythonometrist commented 5 years ago

I am trying to figure out if I can pass a custom loss function to the underlying BERT model. Is that something I can do from your code, or do I need to modify the BERT models in pytorch-transformers? The issue is that I can locate the next sentence, masked label, and LM modules, but I'm not sure which one handles the binary classification label. Any tips/suggestions? This would be useful functionality for biased samples...

pythonometrist commented 5 years ago

It looks like the loss is being assigned in BertForSequenceClassification, and it doesn't look like I can override the loss function from outside of the HuggingFace implementation.

ThilinaRajapakse commented 5 years ago

You should be able to. Subclass BertForSequenceClassification and build your own module; BertForSequenceClassification is simply a torch.nn.Module, so you can override its forward() method to use your custom loss function. The relevant forward() looks like this:

    # Note: CrossEntropyLoss and MSELoss below come from torch.nn
    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids, 
                            head_mask=head_mask)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

Replace CrossEntropyLoss() with your custom loss function.
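
For example, here is a minimal sketch of such a subclass. The class name, the choice of a class-weighted cross-entropy, and the weight values are illustrative assumptions (not part of this repo); the idea is to compensate for biased/imbalanced samples by weighting the minority class more heavily:

    import torch
    from torch.nn import CrossEntropyLoss
    from pytorch_transformers import BertForSequenceClassification

    class WeightedLossBert(BertForSequenceClassification):
        """BertForSequenceClassification with a class-weighted cross-entropy loss."""

        def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                    position_ids=None, head_mask=None, labels=None):
            # Run the BERT encoder and classification head exactly as the parent class does.
            outputs = self.bert(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
                                head_mask=head_mask)
            pooled_output = self.dropout(outputs[1])
            logits = self.classifier(pooled_output)
            outputs = (logits,) + outputs[2:]

            if labels is not None:
                # Hypothetical class weights for an imbalanced binary task;
                # substitute whatever loss function suits your data.
                weight = torch.tensor([1.0, 5.0], device=logits.device)
                loss_fct = CrossEntropyLoss(weight=weight)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

Loading it with the usual from_pretrained() call, e.g. WeightedLossBert.from_pretrained('bert-base-uncased', num_labels=2), should keep the rest of the training script unchanged, since the output tuple has the same shape as before.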

pythonometrist commented 5 years ago

Makes sense - Let me take a shot at it. Thanks!

ThilinaRajapakse commented 5 years ago

Good luck!

pythonometrist commented 5 years ago

It works!

ThilinaRajapakse commented 5 years ago

Awesome!