lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

add weight on loss function #25

Closed: jmlongriver12 closed this issue 4 years ago

jmlongriver12 commented 4 years ago

Hi,

I have a question: how can I add weights to the loss function, for example L = w1 * Ls + w2 * Lr?

Thanks

markus-eberts commented 4 years ago

Hi, you can add loss weighting at line 42 of 'spert/loss.py'. Ideally, the weights should be passed to the SpERTLoss constructor, like so:

import torch


# 'Loss' is the abstract base class defined earlier in spert/loss.py
class SpERTLoss(Loss):
    def __init__(self, rel_criterion, entity_criterion, model, optimizer, scheduler,
                 max_grad_norm, rel_weight, entity_weight):
        self._rel_criterion = rel_criterion
        self._entity_criterion = entity_criterion
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._max_grad_norm = max_grad_norm
        # new: weighting factors for the entity and relation loss terms
        self._rel_weight = rel_weight
        self._entity_weight = entity_weight

    def compute(self, entity_logits, rel_logits, entity_types, rel_types, entity_sample_masks, rel_sample_masks):
        # entity loss: cross-entropy per entity candidate, averaged over valid (unmasked) samples
        entity_logits = entity_logits.view(-1, entity_logits.shape[-1])
        entity_types = entity_types.view(-1)
        entity_sample_masks = entity_sample_masks.view(-1).float()

        entity_loss = self._entity_criterion(entity_logits, entity_types)
        entity_loss = (entity_loss * entity_sample_masks).sum() / entity_sample_masks.sum()

        # relation loss: multi-label BCE, averaged over valid (unmasked) relation candidates
        rel_sample_masks = rel_sample_masks.view(-1).float()
        rel_count = rel_sample_masks.sum()

        if rel_count.item() != 0:
            rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
            rel_types = rel_types.view(-1, rel_types.shape[-1])

            rel_loss = self._rel_criterion(rel_logits, rel_types)
            rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]  # average over relation classes
            rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count

            # joint loss: weighted combination of entity and relation loss
            train_loss = self._entity_weight * entity_loss + self._rel_weight * rel_loss
        else:
            # corner case: no positive/negative relation samples
            train_loss = entity_loss

        # backpropagate, clip gradients and update model parameters
        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._max_grad_norm)
        self._optimizer.step()
        self._scheduler.step()
        self._model.zero_grad()
        return train_loss.item()
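
For reference, here is a minimal sketch of how the weighted loss could be constructed. The criteria follow the pattern in SpERT's trainer (both with reduction='none', since the loss is masked and averaged manually above); the concrete weight values are illustrative assumptions, not recommendations:

rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')

# illustrative weights; tune them on your dataset
compute_loss = SpERTLoss(rel_criterion, entity_criterion, model, optimizer, scheduler,
                         max_grad_norm=1.0, rel_weight=1.0, entity_weight=1.0)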
jmlongriver12 commented 4 years ago

Thanks! Can we learn the weights during training instead of specifying them manually?

markus-eberts commented 4 years ago

Of course you can. You may have a look at this paper. In PyTorch, you need to create the weights as tensors and make them known to the optimizer, e.g. by registering them as nn.Parameter instances in the model ('SpERT' in 'spert/models.py').
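
A minimal sketch of one way to do this, using the uncertainty-based loss weighting idea common in multi-task learning (one learnable log-variance per task). The module and attribute names below are illustrative assumptions, not part of SpERT:

import torch
import torch.nn as nn

class UncertaintyLossWeighting(nn.Module):
    def __init__(self):
        super().__init__()
        # one learnable log-variance per task, initialised to 0 (i.e. weight 1)
        self.entity_log_var = nn.Parameter(torch.zeros(()))
        self.rel_log_var = nn.Parameter(torch.zeros(()))

    def forward(self, entity_loss, rel_loss):
        # L = exp(-s_e) * L_e + s_e + exp(-s_r) * L_r + s_r
        # the additive log-variance terms keep the weights from collapsing to zero
        return (torch.exp(-self.entity_log_var) * entity_loss + self.entity_log_var
                + torch.exp(-self.rel_log_var) * rel_loss + self.rel_log_var)

If an instance of this module is created as an attribute of the SpERT model, its parameters show up in model.parameters() and are updated by the optimizer together with the rest of the network.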