johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

cosface loss bug #44

Closed. hazard-10 closed this issue 3 months ago.

hazard-10 commented 3 months ago

In neg_sim = F.cosine_similarity(pos_pair[0], neg_pair[1], dim=0), I think the arguments should be neg_pair[0], neg_pair[1] instead. Source: https://github.com/johndpope/MegaPortrait-hack/blob/baee8bacec18492da56a09d200e605fc90ac6b03/train.py#L102C43-L102C51

samsara-ku commented 3 months ago

I would also suggest adding something like this here:

AS-IS:

loss = loss + torch.log(torch.exp(pos_dist) / (torch.exp(pos_dist) + neg_term))

TO-BE:

epsilon = 1e-8
loss = loss + torch.log(torch.exp(pos_dist) / (torch.exp(pos_dist) + neg_term + epsilon))

In my experience, training with the cycle consistency loss is problematic. In the overfitting case (i.e. training on a single identity image), I never run into the problem. However, when I train on the full dataset, I always get a NaN loss, because the cycle consistency loss converges to a very small number.
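A hedged aside on the NaN issue: with an epsilon only in the denominator, torch.log(torch.exp(pos_dist) / (...)) can still return nan or -inf whenever torch.exp overflows (roughly above 88 in float32) or underflows to zero. A common alternative is to compute the whole ratio in log space with torch.logsumexp. The sketch below assumes pos_dist is a scalar tensor and that neg_term is the sum of exp over a 1-D tensor of negative distances (called neg_dists here); both helper names are hypothetical, and this does not diagnose the cycle consistency loss itself.

```python
import torch


def stable_log_ratio(pos_dist: torch.Tensor, neg_dists: torch.Tensor) -> torch.Tensor:
    """log(exp(p) / (exp(p) + sum_k exp(n_k))) evaluated without calling exp() directly.

    Hypothetical helper. pos_dist: scalar tensor; neg_dists: 1-D tensor.
    """
    # log(exp(p) / (exp(p) + sum_k exp(n_k))) = p - log(exp(p) + sum_k exp(n_k))
    #                                          = p - logsumexp([p, n_1, ..., n_K])
    all_dists = torch.cat([pos_dist.reshape(1), neg_dists.reshape(-1)])
    return pos_dist - torch.logsumexp(all_dists, dim=0)


def naive_log_ratio(pos_dist, neg_dists, epsilon=1e-8):
    """The TO-BE expression, assuming neg_term = sum(exp(neg_dists))."""
    neg_term = torch.exp(neg_dists).sum()
    return torch.log(torch.exp(pos_dist) / (torch.exp(pos_dist) + neg_term + epsilon))


if __name__ == "__main__":
    pos = torch.tensor(100.0)
    neg = torch.tensor([0.0, -5.0])
    print(naive_log_ratio(pos, neg))   # nan: exp(100) overflows to inf, and inf / inf is nan
    print(stable_log_ratio(pos, neg))  # close to 0
```

Since logsumexp subtracts the maximum before exponentiating, no intermediate exp() can overflow, and the result keeps the same sign convention as the AS-IS/TO-BE lines.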

johndpope commented 3 months ago

At first glance, the code seems fine with respect to @hazard-10's concern.
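A minimal sketch of when the two lines agree, assuming 1-D embedding vectors (hence dim=0) and assuming each negative pair is built around the same anchor as its positive pair, i.e. neg_pair[0] is pos_pair[0]. The tensors are random stand-ins, not the repo's real descriptors. Under that assumption both versions compute the anchor-vs-negative similarity; if the negative pairs are built with a different first element, the two lines are no longer equivalent.

```python
import torch
import torch.nn.functional as F

# Hypothetical 1-D embeddings; the repo's real tensors may have other shapes.
torch.manual_seed(0)
anchor, positive, negative = torch.randn(3, 128).unbind(0)

# Assumption: the negative pair reuses the anchor of the positive pair.
pos_pair = (anchor, positive)
neg_pair = (anchor, negative)

# Line from train.py as quoted in the issue: anchor vs. negative sample.
neg_sim_repo = F.cosine_similarity(pos_pair[0], neg_pair[1], dim=0)
# Variant proposed in the issue.
neg_sim_proposed = F.cosine_similarity(neg_pair[0], neg_pair[1], dim=0)

print(torch.equal(neg_sim_repo, neg_sim_proposed))  # True, because pos_pair[0] is neg_pair[0]
```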

ChatGPT-4o: https://chatgpt.com/c/eda359a9-b61f-4010-b742-b781bde94e01
[Screenshot from 2024-06-15 16-09-50]

But as a sanity check, I ripped out the cosine_loss function, gave Claude Opus the CosFace paper and the MegaPortraits paper (from the reference folder), and asked it to fill the function back in.

Claude Opus:
[Screenshot from 2024-06-15 16-16-00]

Interestingly, it adds margin and scale parameters.
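For reference, the large margin cosine loss in the CosFace paper scales every cosine by s but subtracts the margin m only from the cosine of the target class (x_i is a feature, y_i its label, theta_{j,i} the angle between the normalized feature and the normalized weight of class j, N the batch size). This is the classification form from the paper, not the pairwise form used in this repo:

```latex
L_{lmc} = \frac{1}{N} \sum_{i} -\log
  \frac{e^{s\,(\cos\theta_{y_i,i} - m)}}
       {e^{s\,(\cos\theta_{y_i,i} - m)} + \sum_{j \neq y_i} e^{s \cos\theta_{j,i}}}
```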

Please test; I'll merge whatever you think is best.

If you have code that supports 256, please submit a PR.

Here's a PyTorch implementation of the cosine loss function that aligns with the paper:

import torch
import torch.nn.functional as F

def cosine_loss(positive_pairs, negative_pairs, margin=0.5, scale=5):
    """
    Calculates the cosine loss for the positive and negative pairs.

    Args:
        positive_pairs (list): List of tuples containing positive pairs (z_i, z_j).
        negative_pairs (list): List of tuples containing negative pairs (z_i, z_j).
        margin (float): Margin value for the cosine distance (default: 0.5).
        scale (float): Scaling factor for the cosine distance (default: 5).

    Returns:
        torch.Tensor: Cosine loss value.
    """
    def cosine_distance(z_i, z_j):
        # Normalize the feature vectors
        z_i = F.normalize(z_i, dim=-1)
        z_j = F.normalize(z_j, dim=-1)

        # Calculate the cosine similarity
        cos_sim = torch.sum(z_i * z_j, dim=-1)

        # Apply the scaling and margin
        cos_dist = scale * (cos_sim - margin)

        return cos_dist

    # Calculate the cosine distance for positive pairs
    pos_cos_dist = [cosine_distance(z_i, z_j) for z_i, z_j in positive_pairs]
    pos_cos_dist = torch.stack(pos_cos_dist)

    # Calculate the cosine distance for negative pairs
    neg_cos_dist = [cosine_distance(z_i, z_j) for z_i, z_j in negative_pairs]
    neg_cos_dist = torch.stack(neg_cos_dist)

    # Calculate the cosine loss
    loss = -torch.log(torch.exp(pos_cos_dist) / (torch.exp(pos_cos_dist) + torch.sum(torch.exp(neg_cos_dist))))

    return loss.mean()

This implementation follows the cosine loss formulation described in the paper:

  1. The cosine distance is calculated for each pair of feature vectors (z_i, z_j) in both positive and negative pairs.
  2. The feature vectors are normalized using L2 normalization before calculating the cosine similarity.
  3. The margin (margin) is subtracted from the cosine similarity, and the result is multiplied by the scaling factor (scale) to obtain the cosine distance.
  4. The cosine loss is calculated using the formula: -log(exp(pos_cos_dist) / (exp(pos_cos_dist) + sum(exp(neg_cos_dist)))).
  5. The final loss is the mean of the cosine losses for all positive pairs.
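Writing out step 4 with the step-3 distance d = s(c - m), where c is the cosine similarity of a pair, shows a property worth checking: because the same margin is subtracted from every positive and every negative pair, the factor e^{-sm} appears in every term and cancels:

```latex
% c_p: cosine similarity of a positive pair, c_n: of a negative pair
-\log \frac{e^{s(c_p - m)}}{e^{s(c_p - m)} + \sum_{n} e^{s(c_n - m)}}
  = -\log \frac{e^{-sm}\, e^{s c_p}}{e^{-sm}\left(e^{s c_p} + \sum_{n} e^{s c_n}\right)}
  = -\log \frac{e^{s c_p}}{e^{s c_p} + \sum_{n} e^{s c_n}}
```

So in this form only scale changes the loss value. A CosFace-style margin would be subtracted from the positive pair's similarity only.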

To use this function, you need to provide the positive and negative pairs as lists of tuples, where each tuple contains the feature vectors (z_i, z_j) for a pair. The margin and scale parameters can be adjusted according to your needs.

Note: Make sure to import the necessary PyTorch modules (torch and torch.nn.functional) before using this function.
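For testing, a compact, batched restatement of step 4 can be handy. This is a sketch, not the repo's code: pairwise_cos_loss is a hypothetical helper that takes stacked (P, D) and (N, D) tensors instead of lists of tuples, uses F.cosine_similarity in place of normalize-and-dot, and evaluates the log in log space for stability. Like the function above, it sums over all negative pairs for every positive pair. Running it with two margins is a quick numerical check of whether the margin actually affects the loss.

```python
import torch
import torch.nn.functional as F


def pairwise_cos_loss(pos_a, pos_b, neg_a, neg_b, margin=0.5, scale=5.0):
    """Batched restatement of step 4 (hypothetical helper).

    pos_a, pos_b: (P, D) tensors; row k holds the positive pair (z_i, z_j).
    neg_a, neg_b: (N, D) tensors; row k holds a negative pair (z_i, z_j).
    """
    # Cosine distance d = scale * (cos_sim - margin) for every pair.
    pos_d = scale * (F.cosine_similarity(pos_a, pos_b, dim=-1) - margin)  # (P,)
    neg_d = scale * (F.cosine_similarity(neg_a, neg_b, dim=-1) - margin)  # (N,)
    # log(exp(pos_d) + sum_n exp(neg_d)) for every positive, computed in log space.
    log_denom = torch.logaddexp(pos_d, torch.logsumexp(neg_d, dim=0))
    # -log(exp(pos_d) / denom) = log_denom - pos_d, averaged over positive pairs.
    return (log_denom - pos_d).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    pos_a, pos_b = torch.randn(4, 128), torch.randn(4, 128)
    neg_a, neg_b = torch.randn(6, 128), torch.randn(6, 128)
    with_margin = pairwise_cos_loss(pos_a, pos_b, neg_a, neg_b, margin=0.5)
    no_margin = pairwise_cos_loss(pos_a, pos_b, neg_a, neg_b, margin=0.0)
    print(torch.allclose(with_margin, no_margin, atol=1e-5))  # True: the margin cancels
```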

johndpope commented 3 months ago

I think this should be fine now.

Some nice insights on how CosFace works: https://chatgpt.com/share/abee7def-79c8-4aba-89bf-b3408c46bf18