TencentARC / Open-MAGVIT2

Open-MAGVIT2: Democratizing Autoregressive Visual Generation

Why not use `per_sample_entropy` loss for backward? #25

Closed: xesdiny closed this issue 1 month ago

xesdiny commented 1 month ago

I see that LFQ produces three loss terms; a rough sketch of what each term measures follows below.

# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions (per_sample_entropy)
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch (codebook_entropy)
# 3. commitment_loss (important)
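
For reference, here is a minimal PyTorch sketch of what the three terms measure. This is illustrative only, not the repository's implementation; the function name `lfq_loss_terms`, its arguments, and the `eps` value are mine.

    import torch
    import torch.nn.functional as F

    def lfq_loss_terms(logits, z_continuous, z_quantized, eps=1e-8):
        # logits: (num_tokens, codebook_size) affinities to each code
        # z_continuous / z_quantized: encoder features before / after quantization
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)

        # 1. per-sample entropy: low => each token commits to one code confidently
        per_sample_entropy = -(probs * log_probs).sum(dim=-1).mean()

        # 2. codebook entropy: entropy of the batch-averaged code distribution,
        #    high => codes are used uniformly across the batch
        avg_probs = probs.mean(dim=0)
        codebook_entropy = -(avg_probs * torch.log(avg_probs + eps)).sum()

        # 3. commitment: keep encoder outputs close to their quantized values
        commitment = F.mse_loss(z_continuous, z_quantized.detach())

        return per_sample_entropy, codebook_entropy, commitment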

In VQLPIPSWithDiscriminator.py, these two entropy terms appear to be only logged:

...
    loss = nll_loss + g_loss + scale_codebook_loss + loss_break.commitment * self.commit_weight
    if disc_factor == 0:
        log = {"{}/total_loss".format(split): loss.clone().detach(),
               "{}/per_sample_entropy".format(split): loss_break.per_sample_entropy.detach(),
               "{}/codebook_entropy".format(split): loss_break.codebook_entropy.detach(),
...
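
(Side note: the `.detach()` calls above only affect the copies stored in `log`; whether a term still participates in the backward pass depends on whether it is folded into `loss`. The snippet below is a standalone way to probe that, not repository code; all names in it are made up.)

    import torch

    x = torch.randn(8, requires_grad=True)
    term_a = x.pow(2).mean()   # imagine this one is added to the training loss
    term_b = x.exp().mean()    # imagine this one is only logged

    loss = 2.0 * term_a
    log = {"term_a": term_a.detach(), "term_b": term_b.detach()}  # detached copies for logging only

    # grad of the loss w.r.t. each intermediate term: None means that term
    # does not participate in the backward pass of `loss`
    g_a, g_b = torch.autograd.grad(loss, [term_a, term_b], allow_unused=True)
    print(g_a, g_b)  # tensor(2.) None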

Is it that these entropy terms are only tracked as metrics here, rather than being backpropagated directly?

""" Entropy loss of unnormalized logits

logits: Affinities are over the last dimension

https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) """

xesdiny commented 1 month ago

Oops! I forgot to mention the weighting construct for losses 1 and 2 in lookup_free_quantize.py:

...
    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss
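
If I read this right, both entropy terms reach the backward pass through this combined `loss`; the values logged individually are just detached copies. A standalone toy check of that pattern (the weights and tensor shapes below are illustrative, not the repository's settings):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(16, 512, requires_grad=True)
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    sample_entropy = -(probs * log_probs).sum(-1).mean()            # term 1
    avg_probs = probs.mean(0)
    avg_entropy = -(avg_probs * torch.log(avg_probs + 1e-8)).sum()  # term 2

    sample_minimization_weight = 1.0   # illustrative values
    batch_maximization_weight = 1.0

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    loss.backward()
    print(logits.grad.abs().sum() > 0)  # tensor(True): gradients flow through both entropy terms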