Closed: johndpope closed this issue 4 months ago.
class Gbase(nn.Module):
    """Base generator that renders the source frame ``xs`` with the motion of ``xd``.

    Data flow, as implemented in ``forward``:

    1. ``Eapp`` turns the source frame into volumetric features ``vs`` and a
       global descriptor ``es`` (the latter is currently unused).
    2. ``Emtn`` runs on both frames and yields a rotation, a translation and a
       latent code ``z`` for each of them.
    3. A source->canonical warp is built from two parts: a rigid part from the
       source rotation/translation, and a part predicted from ``zs``. This
       warp moves ``vs`` into a canonical volume, which ``G3d`` then processes.
    4. A canonical->driving warp, built the same way from the driving frame,
       imposes the driving motion on that volume.
    5. The volume is projected by summing over dim 2 and decoded by ``G2d``
       into the final output image.
    """

    def __init__(self):
        super().__init__()
        # nn.Module registers submodules under these attribute names, in
        # assignment order. The names and order therefore fix the state_dict
        # keys and parameter order, so keep them unchanged for checkpoints.
        self.appearanceEncoder = Eapp()
        self.motionEncoder = Emtn()
        # Expression-driven warp generators, one per direction.
        self.warp_generator_s2c = WarpGenerator(in_channels=2048)  # source -> canonical
        self.warp_generator_c2d = WarpGenerator(in_channels=2048)  # canonical -> driving
        self.G3d = G3d(in_channels=96)
        self.G2d = G2d(in_channels=96)

    def forward(self, xs, xd):
        """Re-enact ``xs`` using the head motion and expression found in ``xd``.

        Args:
            xs: source frame batch (provides appearance).
            xd: driving frame batch (provides motion).

        Returns:
            The output of ``G2d``, i.e. the final output image ``xhat``.

        Raises:
            AssertionError: if the appearance volume's non-batch shape is not
                (96, 16, 64, 64). Only raised when Python assertions are
                enabled; the check is skipped under ``python -O``.
        """
        volume_src, _global_desc = self.appearanceEncoder(xs)
        # Internal invariant on Eapp's output rather than caller input validation.
        assert volume_src.shape[1:] == (96, 16, 64, 64), (
            f"Expected vs shape (_, 96, 16, 64, 64), got {volume_src.shape}"
        )

        # Rotation, translation and latent code for each frame (source first).
        rot_src, trans_src, code_src = self.motionEncoder(xs)
        rot_drv, trans_drv, code_drv = self.motionEncoder(xd)

        # Rigid warps: inverted for source->canonical, forward for canonical->driving.
        rigid_s2c = compute_rt_warp(rot_src, trans_src, invert=True)
        rigid_c2d = compute_rt_warp(rot_drv, trans_drv, invert=False)

        # Non-rigid (expression) 3D warps, predicted from the latent codes alone.
        expr_s2c = self.warp_generator_s2c(code_src)
        expr_c2d = self.warp_generator_c2d(code_drv)

        # Move the source volume into canonical space, then refine it with G3d.
        canonical = self.G3d(apply_warping_field(volume_src, rigid_s2c + expr_s2c))

        # Impose the driving motion on the refined canonical volume.
        driven = apply_warping_field(canonical, rigid_c2d + expr_c2d)

        # Orthographic projection: collapse the volume by summing over dim 2.
        # In ``vs`` that is the depth axis of size 16; this assumes G3d and the
        # warps keep that layout (TODO confirm).
        projected = torch.sum(driven, dim=2)

        # TODO: per the warp diagram, only the latent codes feed the warp
        # generators. It is still open how, or whether, the global descriptor
        # from Eapp (``_global_desc``, a.k.a. ``es``) should be used.
        return self.G2d(projected)