huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
134.85k stars, 26.97k forks

Enable changing the loss function by making the hard-coded `loss_fct` an attribute of `BertForTokenClassification`. #33642

Open tom-010 opened 1 month ago

tom-010 commented 1 month ago

Feature request

In `transformers.models.bert.modeling_bert.BertForTokenClassification.forward`, the loss function is hard-coded as `loss_fct = CrossEntropyLoss()`. To change it (e.g., to set class weights in `CrossEntropyLoss`), one must currently monkey-patch the model. If `loss_fct` were an attribute (e.g., `self.loss_fct`), users could simply replace it and use custom loss functions during training.

Motivation

The motivation behind this proposal stems from the need to change the loss function for fine-tuning a pre-trained BERT model for token classification, particularly when dealing with imbalanced classes. In my use case, I need to prioritize recall, as most tokens belong to the "other" class. To achieve this, I need to set custom weights in the CrossEntropyLoss, like this:

loss_fct = CrossEntropyLoss(weight=torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0], device=self.device))
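To make the effect of such weights concrete, here is a minimal pure-Python sketch of a weighted cross-entropy. It assumes the weighted-mean formula that, as far as I understand, PyTorch documents for `CrossEntropyLoss` with `reduction="mean"`: each token's loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights. The function name and the toy batch are hypothetical.

```python
import math

def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-token cross-entropy, weighted by the true class."""
    total, weight_sum = 0.0, 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp over the class scores of one token.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        nll = log_norm - row[label]  # negative log-likelihood of the true class
        total += weights[label] * nll
        weight_sum += weights[label]
    return total / weight_sum

# Hypothetical toy batch: three tokens, two classes (0 = "other", 1 = entity).
# The model favours "other" everywhere, so the single entity token is missed.
logits = [[2.0, 0.0], [2.0, 0.0], [2.0, 0.0]]
labels = [0, 0, 1]

uniform = weighted_cross_entropy(logits, labels, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [0.1, 1.0])
print(uniform, weighted)
```

With the "other" class down-weighted, the missed entity token dominates the loss, which is the behaviour I am after when prioritizing recall.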

However, since the loss function is hard-coded inside the forward method, modifying it currently requires overriding the entire method just to change one line, as shown here:

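from typing import Optional, Tuple, Union

import torch
from fastcore.basics import patch  # assumption: `@patch` is fastcore's decorator
from torch.nn import CrossEntropyLoss
from transformers import BertForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput

@patch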
def forward(
        self: BertForTokenClassification,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], 'TokenClassifierOutput']:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            class_weights = torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0], device=self.device)
            loss_fct = CrossEntropyLoss(weight=class_weights)  # <-- the only change (together with class_weights above)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

By turning loss_fct into an attribute, we could avoid the need for monkey-patching. Swapping in a custom loss could then be as simple as:

class_weights = torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0], device=model.device)
model.loss_fct = CrossEntropyLoss(weight=class_weights)

This would leave existing code unchanged but make it easier to swap in a custom loss function when needed.
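A minimal sketch of what the proposal could look like on the model side. `TinyTokenClassifier` is a hypothetical stand-in for the classification head, not the real `BertForTokenClassification`; the only point is that `forward` reads `self.loss_fct` instead of constructing the loss itself.

```python
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

class TinyTokenClassifier(nn.Module):
    """Hypothetical stand-in for a token-classification head (not the real BERT class)."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.num_labels = num_labels
        self.classifier = nn.Linear(hidden_size, num_labels)
        # Proposed change: the loss lives on the instance, with the current default.
        self.loss_fct = CrossEntropyLoss()

    def forward(self, hidden_states, labels=None):
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            # Same call as today, but through the replaceable attribute.
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits

torch.manual_seed(0)
model = TinyTokenClassifier(hidden_size=8, num_labels=5)
hidden_states = torch.randn(2, 4, 8)   # (batch, sequence, hidden)
labels = torch.randint(0, 5, (2, 4))   # (batch, sequence)

default_loss, _ = model(hidden_states, labels)

# Swap in a weighted loss without touching forward().
class_weights = torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0])
model.loss_fct = CrossEntropyLoss(weight=class_weights)
weighted_loss, logits = model(hidden_states, labels)
```

Because the default stays `CrossEntropyLoss()`, callers who never touch `loss_fct` would get the same loss as before.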

Your contribution

I am new to this repository and this would be my first pull request. I would like to ask if these types of changes are welcomed, and if it makes sense to proceed with submitting a pull request for this improvement.

LysandreJik commented 1 month ago

Hey @tom-010, the way transformers is designed is to expose a simple, common loss function, but to also return just the base logits in case you don't want that method.

Just don't pass the labels and compute the loss outside the model, and you're good :ok_hand:
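A rough sketch of that approach, using a random tensor as a stand-in for the logits the model returns when no labels are passed. The class weights are the ones from the request above; the use of -100 for positions to skip is an assumption here, chosen because it is the default `ignore_index` of `CrossEntropyLoss`.

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
torch.manual_seed(0)

# Stand-in for `model(**inputs).logits` when no labels are passed:
# shape (batch_size, sequence_length, num_labels).
logits = torch.randn(2, 4, num_labels)

# Assumption: -100 marks positions to skip (e.g. padding); it is the
# default ignore_index of CrossEntropyLoss.
labels = torch.tensor([[0, 3, 0, -100], [4, 0, -100, -100]])

class_weights = torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0])
loss_fct = CrossEntropyLoss(weight=class_weights)

# Flatten to (tokens, classes) and (tokens,), as the built-in loss does.
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
```

In a hand-written training loop, `loss.backward()` would follow as usual.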

tom-010 commented 1 month ago

Thank you for the info, @LysandreJik! 😊 I initially followed that approach but encountered some issues while using the Trainer.

I ended up subclassing Trainer and overriding the compute_loss method, as suggested in the docstring:

def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """

Here’s an example of how I implemented it:

class CustomLossTrainer(Trainer):
    def __init__(self, *args, loss_fct, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the custom loss function (e.g., CrossEntropyLoss with class weights)
        self.loss_fct = loss_fct

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Manually compute loss using the provided custom loss function
        loss = self.loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

However, I ran into the issue that this approach involves copying a significant portion of the code from Trainer.compute_loss. While this works, it’s not ideal because I would miss out on any future updates or changes made to the logic in Trainer.

It would be great if there were a cleaner way to just inject the custom loss function without needing to replicate existing code. IMO model.loss_fct = CrossEntropyLoss(weight=class_weights) would be a nice way. Or am I missing something?

If not: patching and subclassing work for me right now, so I am fine, but I could contribute if this change is welcomed. If not, feel free to close the issue :+1: Note that the same issue exists in other token classification models as well, e.g. DebertaV2ForTokenClassification.

LysandreJik commented 1 month ago

Pinging @muellerzr for when he's back from leave, in case he wants to chime in. It would require quite a significant change across all models, however.

LysandreJik commented 1 month ago

Thanks for the feature request though! I understand it would make things much easier in your case