NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

LUKE finetuning advice #78

Open aakashb95 opened 2 years ago

aakashb95 commented 2 years ago

Hi Niels, Thanks for your notebook on fine-tuning LUKE on a custom dataset.

I recently came across this dataset on the hub: https://huggingface.co/datasets/xiaobendanyn/tacred

I have converted it to the format you use in this notebook: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LUKE/Supervised_relation_extraction_with_LukeForEntityPairClassification.ipynb I reused your model and dataset classes as they are and only changed the loss function for experimentation.

There is a huge class imbalance in this dataset: [image]

Since no_relation examples (negative samples) are very important to us (most real-world text doesn't contain a relationship between two entities), I thought of using a weighted CrossEntropyLoss instead of sampling all labels equally. I computed a weight for each label like so: [image]

I want the model to pay more attention to the less frequent labels, so I am passing the weights to the CrossEntropyLoss function:

# ...
    def common_step(self, batch, batch_idx):
        labels = batch["label"]
        del batch["label"]
        outputs = self(**batch)
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss(weight=normedWeights)
# ... rest of the code
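
The weight computation itself was only shown in an image, so the following is a minimal sketch of one common choice (inverse-frequency weights normalized to sum to 1), not necessarily what was used here. The label counts, the variable name `normed_weights`, and the dummy batch are all assumptions made for illustration.

```python
import torch

# Hypothetical label counts (assumption); index 0 plays the role of no_relation.
label_counts = torch.tensor([800.0, 150.0, 50.0])

# Inverse-frequency weights, normalized to sum to 1 (one common choice, assumed here).
inv_freq = 1.0 / label_counts
normed_weights = inv_freq / inv_freq.sum()

# The weight tensor needs one entry per class, in class-index order.
criterion = torch.nn.CrossEntropyLoss(weight=normed_weights)

# Dummy batch standing in for outputs.logits and batch["label"].
logits = torch.randn(4, 3)
labels = torch.tensor([0, 1, 2, 0])
loss = criterion(logits, labels)
```

With the default `reduction="mean"`, PyTorch divides the weighted sum by the sum of the weights of the targets in the batch, so multiplying all weights by a constant does not change the loss; only their ratios matter. The weight tensor also has to live on the same device as the logits.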

The weighted CE approach hasn't helped: even on basic examples, the model predicts no_relation: [image]

Further, I used FocalLoss, which is tailor-made for class imbalance problems.

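import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # scalar scale applied to every sample
        self.gamma = gamma  # focusing parameter; larger values down-weight easy samples
        self.ce = nn.CrossEntropyLoss(reduction="none")

    def forward(self, inputs, targets, **kwargs):
        # Per-sample cross-entropy, i.e. -log(p_t) for the true class.
        ce_loss = self.ce(inputs, targets)
        # Recover p_t, the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()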

I still face the same problem. I am still debugging my code and changing a few hyper-parameters. What would be your advice on dealing with class imbalance in this dataset?
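
One detail about the FocalLoss above: a scalar alpha multiplies every sample by the same constant, so it does not favor any class over another; only the gamma term changes the relative weighting of samples. A minimal sketch of a variant with one alpha per class follows. The class name `ClassWeightedFocalLoss` and the example alpha values are hypothetical and not taken from this thread.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassWeightedFocalLoss(nn.Module):
    """Focal loss with a per-class alpha (hypothetical variant, for illustration)."""

    def __init__(self, alpha, gamma=2.0):
        super().__init__()
        # One weight per class; a buffer moves with the module across devices.
        self.register_buffer("alpha", torch.as_tensor(alpha, dtype=torch.float))
        self.gamma = gamma

    def forward(self, logits, targets):
        # Unweighted per-sample cross-entropy, so exp(-ce) is the true-class probability.
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        # Look up the alpha of each sample's true class.
        alpha_t = self.alpha[targets]
        return (alpha_t * (1 - pt) ** self.gamma * ce).mean()


# Example call with made-up per-class alphas (assumption): rare classes get more weight.
loss_fn = ClassWeightedFocalLoss(alpha=[0.1, 0.3, 0.6], gamma=2.0)
example_loss = loss_fn(torch.randn(4, 3), torch.tensor([0, 1, 2, 0]))
```

The cross-entropy is computed without class weights on purpose: if the weights were passed to `F.cross_entropy`, `exp(-ce)` would no longer equal the predicted probability of the true class.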

NielsRogge commented 2 years ago

Hi,

Thanks for your interest in LUKE! I think it makes sense to only train the model on sentences in which a relationship occurs.

This answer is really helpful: https://github.com/lavis-nlp/spert/issues/20#issuecomment-643607794
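
The suggestion above, training only on sentences in which a relationship occurs, can be sketched as a simple filter applied before the dataset is built. This is a minimal illustration: the field names `sentence` and `label`, the example sentences, and the helper `keep_positive` are assumptions, not the notebook's actual schema.

```python
# Hypothetical examples; the field names "sentence" and "label" are assumptions.
examples = [
    {"sentence": "Alice works for Acme.", "label": "per:employee_of"},
    {"sentence": "Alice met Bob in Paris.", "label": "no_relation"},
    {"sentence": "Acme is based in Berlin.", "label": "org:city_of_headquarters"},
]

NEGATIVE_LABEL = "no_relation"


def keep_positive(examples, negative_label=NEGATIVE_LABEL):
    """Return only the examples whose label is an actual relation."""
    return [ex for ex in examples if ex["label"] != negative_label]


train_examples = keep_positive(examples)

# Rebuild the label mapping so it no longer contains the negative class.
label2id = {
    label: idx
    for idx, label in enumerate(sorted({ex["label"] for ex in train_examples}))
}
```

A model trained on this filtered set has no no_relation class in its label space, so negative sentences would have to be handled some other way at inference time.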