johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621
68 stars 7 forks source link

es global descriptor - how to plug this back into Gbase #10

Closed johndpope closed 4 months ago

johndpope commented 4 months ago
class Gbase(nn.Module):
    """Base generator: renders the source image ``xs`` with the motion of ``xd``.

    The forward pass chains the following stages:

    1. The appearance encoder turns ``xs`` into volumetric features ``vs``
       and a global descriptor ``es``.
    2. The motion encoder is run on ``xs`` and on ``xd``; each call yields
       a ``(R, t, z)`` triple.
    3. Each direction (source-to-canonical, canonical-to-driving) gets a
       warp that is the sum of a rotation/translation warp from
       ``compute_rt_warp`` and an expression warp from a ``WarpGenerator``.
    4. ``vs`` is warped to a canonical volume, passed through ``G3d``,
       warped again with the driving warp, summed over ``dim=2`` and
       decoded by ``G2d``.

    TODO(review): ``es`` is produced but never consumed. The original author
    flagged this as an open question; only ``zs`` / ``zd`` reach the warp
    generators. Presumably ``es`` should also condition them -- confirm
    against the MegaPortraits paper and the ``WarpGenerator`` input contract
    before wiring it in.
    """

    def __init__(self):
        """Build the encoders, both warp generators and the 3D/2D generators."""
        super().__init__()

        # Encoders for appearance (volume + global descriptor) and motion.
        self.appearanceEncoder = Eapp()
        self.motionEncoder = Emtn()

        # One expression-warp generator per warping direction.
        self.warp_generator_s2c = WarpGenerator(in_channels=2048)  # source-to-canonical
        self.warp_generator_c2d = WarpGenerator(in_channels=2048)  # canonical-to-driving

        # Volume processor and final image decoder.
        self.G3d = G3d(in_channels=96)
        self.G2d = G2d(in_channels=96)

    def forward(self, xs, xd):
        """Return ``xhat``, the output of ``G2d`` for source ``xs`` driven by ``xd``.

        Args:
            xs: Source input, fed to both the appearance and motion encoders.
            xd: Driving input, fed to the motion encoder only.

        Returns:
            Whatever ``self.G2d`` returns for the projected features.

        Raises:
            AssertionError: If the appearance volume's non-batch shape is not
                ``(96, 16, 64, 64)``.
        """
        # Appearance: volumetric features plus the (currently unused) global
        # descriptor es -- see the TODO in the class docstring.
        vs, es = self.appearanceEncoder(xs)
        assert vs.shape[1:] == (96, 16, 64, 64), f"Expected vs shape (_, 96, 16, 64, 64), got {vs.shape}"

        # Motion descriptors; the source is encoded before the driving input.
        Rs, ts, zs = self.motionEncoder(xs)
        Rd, td, zd = self.motionEncoder(xd)

        # Rotation/translation warps. The source one is requested with
        # invert=True (towards canonical), the driving one with invert=False.
        rigid_s2c = compute_rt_warp(Rs, ts, invert=True)
        rigid_c2d = compute_rt_warp(Rd, td, invert=False)

        # Expression warps, driven solely by the z descriptors.
        expression_s2c = self.warp_generator_s2c(zs)
        expression_c2d = self.warp_generator_c2d(zd)

        # Source volume -> canonical volume -> G3d.
        canonical_volume = apply_warping_field(vs, rigid_s2c + expression_s2c)
        processed_volume = self.G3d(canonical_volume)

        # Impose the driving motion on the processed canonical volume.
        driven_volume = apply_warping_field(processed_volume, rigid_c2d + expression_c2d)

        # Orthographic projection: collapse dim 2 by summation
        # (presumably the depth axis -- confirm against G3d's output layout).
        projected_features = driven_volume.sum(dim=2)

        return self.G2d(projected_features)