huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

checkpoints are not saved after implementing a custom loss #12484

Closed yhifny closed 3 years ago

yhifny commented 3 years ago

I subclassed the `Trainer` class to implement the supervised contrastive loss function (code found via a Google search). The code runs, but it does not save any model checkpoints.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import EarlyStoppingCallback, Trainer


class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature, device):
        """
        Implementation of the loss described in the paper
        Supervised Contrastive Learning: https://arxiv.org/abs/2004.11362

        :param temperature: float, temperature scaling applied to the dot products
        :param device: device the similarity masks are created on
        """
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size]
        :return: torch.Tensor, scalar
        """
        projections = F.normalize(projections, p=2, dim=1)
        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # Subtract the row-wise max for numerical stability with the exponential,
        # as done in cross entropy. Epsilon added to avoid log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        # Positive-pair mask: same class, excluding each sample as its own anchor.
        mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(self.device)
        mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(self.device)
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss


class SCLTrainer(Trainer):
    def __init__(self, temperature, loss_weight, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_callback(EarlyStoppingCallback())
        self.temperature = temperature
        self.loss_weight = loss_weight
        self.device = self.args.device
        print("SCL loss_weight: ", self.loss_weight)
        print("SCL temperature: ", self.temperature)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # [CLS] embedding from the last hidden layer; requires the model to be
        # loaded with output_hidden_states=True, otherwise hidden_states is None.
        feature_vectors = outputs.hidden_states[-1][:, 0, :]
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]

        self.ce_loss = nn.CrossEntropyLoss().to(self.device)
        self.scl_loss = SupervisedContrastiveLoss(self.temperature, self.device).to(self.device)

        # Weighted combination of cross entropy and supervised contrastive loss.
        loss = (1 - self.loss_weight) * self.ce_loss(logits, labels) + self.loss_weight * self.scl_loss(
            feature_vectors, labels
        )

        return (loss, outputs) if return_outputs else loss

```

Can you please let me know where the problem is?
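
For context, the surrounding setup looks roughly like the sketch below (the model name, paths, hyperparameter values, and dataset variables are placeholders, not my actual script). Note that `EarlyStoppingCallback` refuses to run unless `load_best_model_at_end=True` and a `metric_for_best_model` are set, and `compute_loss` above needs the model loaded with `output_hidden_states=True`:

```python
from transformers import AutoModelForSequenceClassification, TrainingArguments

# output_hidden_states=True so compute_loss can read outputs.hidden_states
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2, output_hidden_states=True  # placeholder model/labels
)

training_args = TrainingArguments(
    output_dir="./scl_checkpoints",  # placeholder path
    evaluation_strategy="epoch",     # EarlyStoppingCallback needs periodic evaluation
    save_strategy="epoch",           # must match evaluation_strategy with load_best_model_at_end
    load_best_model_at_end=True,     # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    num_train_epochs=5,
)

trainer = SCLTrainer(
    temperature=0.3,                 # placeholder hyperparameters
    loss_weight=0.9,
    model=model,
    args=training_args,
    train_dataset=train_dataset,     # datasets defined elsewhere
    eval_dataset=eval_dataset,
)
trainer.train()
```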

LysandreJik commented 3 years ago

Hello! Could you provide the command you used to run your script? If it's not an official example script, do you have the code you used handy? Thanks!

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.