TingtingLiao / TADA

[3DV 2024] Official Repository for "TADA! Text to Animatable Digital Avatars".
https://tada.is.tue.mpg.de
MIT License

The implementation of the interpolated latent code of Equation (6) #18

Closed A-pril closed 3 months ago

A-pril commented 3 months ago

I tried to find where the interpolated latent code of Equation (6) in the paper, z = α z^I + (1 − α) z^N, is implemented. Reading the code, it seems that you just use the image latent code and the normal latent code in two separate losses? Maybe that is an equivalent operation?

    def train_step(self, data, is_full_body):
        ......
        loss = self.guidance.train_step(dir_text_z, image_annel).mean()
        if not self.dpt:
            # normal sds
            loss += self.guidance.train_step(dir_text_z, normal).mean()
            # latent mean sds
            # loss += self.guidance.train_step(dir_text_z, torch.cat([normal, image.detach()])).mean() * 0.1
        else:
            if p_iter < 0.3 or random.random() < 0.5:
                # normal sds
                loss += self.guidance.train_step(dir_text_z, normal).mean()  # use normal map directly
            elif self.dpt is not None:
                # normal image loss
                dpt_normal = self.dpt(image)  # estimate normal map
                dpt_normal = (1 - dpt_normal) * alpha + (1 - alpha)

                lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)
                loss += lambda_normal * (1 - F.cosine_similarity(normal, dpt_normal).mean())

Waiting for your reply, thanks a lot!

TingtingLiao commented 3 months ago

Hi,

It's not separate; you can see the code here: https://github.com/TingtingLiao/TADA/blob/8909a13259df0b4649455e87670a39fd70cdac83/lib/guidance/sd.py#L104.

Best, Tingting
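
For readers following the pointer to sd.py#L104, here is a minimal PyTorch sketch of the idea, assuming the RGB and normal latents are stacked along the batch dimension and averaged before the SDS step (as the torch.mean over dim 0 discussed below suggests). This is not the repository's exact code: interpolated_sds_loss, noise_pred_fn, and the weighting w = 1 − alpha_t are illustrative assumptions; the point is that the batch-mean line realizes Equation (6) with α = 0.5.

    import torch
    import torch.nn.functional as F

    def interpolated_sds_loss(z_image, z_normal, noise_pred_fn, alphas_cumprod, t):
        """SDS loss on the interpolated latent of Eq. (6) with alpha = 0.5 (sketch).

        z_image, z_normal : [1, 4, 64, 64] VAE latents of the RGB and normal renders
        noise_pred_fn     : callable(noisy_latent, t) -> predicted noise (e.g. a UNet
                            already wrapped with text conditioning / CFG)
        alphas_cumprod    : [T] cumulative product of the diffusion alpha schedule
        t                 : integer timestep
        """
        # Stacking the two latents on the batch axis and averaging over dim 0 is
        # exactly z = 0.5 * z^I + 0.5 * z^N, i.e. Eq. (6) with alpha = 0.5.
        z = torch.cat([z_image, z_normal], dim=0).mean(dim=0, keepdim=True)

        noise = torch.randn_like(z)
        a_t = alphas_cumprod[t]
        z_noisy = a_t.sqrt() * z + (1 - a_t).sqrt() * noise

        with torch.no_grad():
            noise_pred = noise_pred_fn(z_noisy, t)

        w = 1 - a_t                      # a common SDS weighting choice
        grad = w * (noise_pred - noise)  # SDS gradient; no backprop through the UNet
        # MSE whose gradient w.r.t. z equals `grad` (standard SDS trick).
        target = (z - grad).detach()
        return 0.5 * F.mse_loss(z, target, reduction="sum")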

A-pril commented 3 months ago

Thanks for the reply. I have read that part of the code and debugged it to check the shape of the latents; it is [1, 4, 64, 64]. Since you call self.guidance.train_step(image_annel) and self.guidance.train_step(normal) separately, I think the torch.mean operation over dim 0 has no effect?

In my opinion, the commented-out code in this part seems consistent with the idea in the paper: https://github.com/TingtingLiao/TADA/blob/8909a13259df0b4649455e87670a39fd70cdac83/lib/trainer.py#L276

TingtingLiao commented 3 months ago

Yes, that line is the latent interpolation loss.
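
For completeness, a general-α version of that interpolation could look like the hypothetical helper below. It is a sketch of the idea behind the commented-out line, not the code at trainer.py#L276; the detach mirrors the image.detach() in that line.

    import torch

    def interpolate_latents(z_image: torch.Tensor,
                            z_normal: torch.Tensor,
                            alpha: float = 0.5,
                            detach_image: bool = True) -> torch.Tensor:
        """Eq. (6): z = alpha * z^I + (1 - alpha) * z^N (illustrative sketch).

        Detaching z^I (as image.detach() does in the commented-out trainer line)
        keeps this term from back-propagating into the appearance branch, so it
        mainly regularizes geometry through the normal latent.
        """
        if detach_image:
            z_image = z_image.detach()
        return alpha * z_image + (1 - alpha) * z_normal

Passing the result through the same SDS step used for the RGB and normal latents, scaled by a small weight such as the 0.1 in the commented-out line, would then give the interpolation loss discussed above.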