johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

WarpingNet / flow fields / grids - reference #15

Closed: johndpope closed this issue 3 weeks ago.

johndpope commented 1 month ago

https://github.com/OpenTalker/video-retalking/blob/d32e8e58248255e2d243eeaf3cba545dbe505ca8/utils/flow_util.py#L4

https://github.com/OpenTalker/video-retalking/blob/d32e8e58248255e2d243eeaf3cba545dbe505ca8/models/DNet.py#L88

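import functools

import torch.nn as nn

# LayerNorm2d and ADAINHourglass are defined in the video-retalking repo linked above;
# flow_util is the module shown below (utils/flow_util.py).

class WarpingNet(nn.Module):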
    def __init__(
        self, 
        image_nc=3, 
        descriptor_nc=256, 
        base_nc=32, 
        max_nc=256, 
        encoder_layer=5, 
        decoder_layer=3, 
        use_spect=False
        ):
        super().__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True) 
        kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}

        self.descriptor_nc = descriptor_nc 
        self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
                                       max_nc, encoder_layer, decoder_layer, **kwargs)

        self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), 
                                      nonlinearity,
                                      nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        final_output={}
        output = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(output)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output

import torch

def convert_flow_to_deformation(flow):
    r"""convert flow fields to deformations.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        deformation (tensor): The deformation used for warping
    """
    b,c,h,w = flow.shape
    flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
    grid = make_coordinate_grid(flow)
    deformation = grid + flow_norm.permute(0,2,3,1)
    return deformation

def make_coordinate_grid(flow):
    r"""obtain coordinate grid with the same size as the flow field.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        grid (tensor): The grid with the same size as the input flow
    """    
    b,c,h,w = flow.shape

    x = torch.arange(w).to(flow)
    y = torch.arange(h).to(flow)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
    meshed = meshed.expand(b, -1, -1, -1)
    return meshed    

def warp_image(source_image, deformation):
    r"""warp the input image according to the deformation

    Args:
        source_image (tensor): source images to be warped
        deformation (tensor): deformations used to warp the images; value in range (-1, 1)
    Returns:
        output (tensor): the warped images
    """ 
    _, h_old, w_old, _ = deformation.shape
    _, _, h, w = source_image.shape
    if h_old != h or w_old != w:
        deformation = deformation.permute(0, 3, 1, 2)
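        deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
        deformation = deformation.permute(0, 2, 3, 1)
    # align_corners=True matches make_coordinate_grid, which maps -1/+1 to the centres of the corner pixels
    return torch.nn.functional.grid_sample(source_image, deformation, align_corners=True)
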
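
Here is a quick sanity check of the convention above, written as a minimal sketch. The flow is measured in pixels, so a zero flow should reproduce the input, and a constant +1 flow in x should move the content one pixel to the left. The helper `warp_with_flow` is hypothetical. It is a compact re-implementation of convert_flow_to_deformation plus warp_image, and it assumes align_corners=True, which matches the (w - 1) normalisation in make_coordinate_grid.

```python
import torch
import torch.nn.functional as F


def warp_with_flow(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp `image` (N, C, H, W) by a pixel-unit `flow` (N, 2, H, W); channel 0 is x, channel 1 is y.

    Hypothetical compact equivalent of convert_flow_to_deformation + warp_image,
    assuming align_corners=True so that -1/+1 map to the centres of the corner pixels.
    """
    n, _, h, w = flow.shape
    # Identity sampling grid in normalised [-1, 1] coordinates, last dim ordered (x, y).
    ys = torch.linspace(-1.0, 1.0, h, dtype=flow.dtype, device=flow.device)
    xs = torch.linspace(-1.0, 1.0, w, dtype=flow.dtype, device=flow.device)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    base = torch.stack([xx, yy], dim=-1).expand(n, -1, -1, -1)
    # Convert pixel offsets to normalised offsets: one pixel spans 2 / (size - 1).
    flow_norm = torch.cat([flow[:, :1] * 2 / (w - 1), flow[:, 1:] * 2 / (h - 1)], dim=1)
    deformation = base + flow_norm.permute(0, 2, 3, 1)
    return F.grid_sample(image, deformation, mode="bilinear", padding_mode="zeros", align_corners=True)


img = torch.arange(20, dtype=torch.float32).view(1, 1, 4, 5)

# Zero flow: the output should equal the input.
same = warp_with_flow(img, torch.zeros(1, 2, 4, 5))

# Constant +1 pixel flow in x: output[..., j] samples input[..., j + 1],
# and the last column samples outside the image, so zero padding makes it 0.
flow = torch.zeros(1, 2, 4, 5)
flow[:, 0] = 1.0
shifted = warp_with_flow(img, flow)
```
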
robinchm commented 1 month ago

I feel the current warping code, which uses affine_grid, is almost correct. You can also supply an identity affine matrix to affine_grid to get a starting meshgrid, instead of building one with linspace and the like. Something like:

import torch
import torch.nn.functional as F

# 2D example with batch_size 1: size is (N, C, H, W)
theta = torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=torch.float32)
grid = F.affine_grid(theta, (1, 16, 64, 64), align_corners=True)  # shape (1, 64, 64, 2)
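
Here is a minimal check of that suggestion, assuming align_corners=True throughout. With an identity theta, affine_grid returns the same normalised meshgrid that make_coordinate_grid builds by hand, and the last dimension is ordered (x, y). The small, non-square size is arbitrary and only makes the two axes easy to tell apart.

```python
import torch
import torch.nn.functional as F

# Identity affine matrix for a batch of 1 (2D case: 2 x 3).
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
n, c, h, w = 1, 16, 4, 5  # non-square so H and W cannot be confused

grid = F.affine_grid(theta, (n, c, h, w), align_corners=True)  # shape (N, H, W, 2)

# The same base grid built by hand, as make_coordinate_grid does.
ys = torch.linspace(-1.0, 1.0, h)
xs = torch.linspace(-1.0, 1.0, w)
yy, xx = torch.meshgrid(ys, xs, indexing="ij")
manual = torch.stack([xx, yy], dim=-1).unsqueeze(0)  # last dim ordered (x, y)
```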

However, I think we don't need such a "base grid".

The unnatural part of the current implementation is the normalization of w. affine_grid produces sampling coordinates in the normalized [-1, 1] space. I can see that w_em and the sum may reach outside [-1, 1], but I feel we should not perform any manual normalization. grid_sample already handles out-of-range coordinates through padding_mode, and the network should learn that w_em needs to be small. In other words, we simply add the two warpings, w = w_em + w_rt, and pass the result to grid_sample. The output of affine_grid already includes the "base grid", so there is no need to add one.

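In short, the apply_warping_field method can be replaced by v_canonical = F.grid_sample(v, w, mode='bilinear', padding_mode='border', align_corners=True). grid_sample has no 'trilinear' mode; for a 5-D input, 'bilinear' already performs trilinear interpolation.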
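
Here is a minimal sketch of that proposal. It assumes v is an (N, C, D, H, W) feature volume and that w_em is a residual offset in the same normalized coordinates. The shapes and the zero w_em are placeholders chosen for illustration. With an identity w_rt and a zero w_em the warp is the identity, which makes the snippet easy to verify.

```python
import torch
import torch.nn.functional as F

n, c, d, h, w = 1, 4, 8, 16, 16  # hypothetical volume size for v
v = torch.randn(n, c, d, h, w)

# Rigid part w_rt: an identity 3 x 4 affine matrix here; in practice it would come from the head pose.
theta = torch.eye(3, 4).unsqueeze(0)
w_rt = F.affine_grid(theta, (n, c, d, h, w), align_corners=True)  # (N, D, H, W, 3), already includes the base grid

# Expression-driven residual w_em (hypothetical; zero here so the result can be checked).
w_em = torch.zeros_like(w_rt)

# Add the two warpings directly: no manual renormalization and no extra base grid.
grid = w_em + w_rt
# For 5-D inputs, mode="bilinear" is trilinear interpolation; grid_sample has no "trilinear" mode.
v_canonical = F.grid_sample(v, grid, mode="bilinear", padding_mode="border", align_corners=True)
```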

johndpope commented 1 month ago

sorry pushed some broken code - fixing now.

robinchm commented 1 month ago

> sorry pushed some broken code - fixing now.

I think you need to keep the interpolate call in the apply method. According to the diagram in the paper, the warp should not change the size of vs. Because grid_sample's output takes the spatial size of the grid, w should be resized to the same spatial size as vs.
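
Here is a minimal sketch of that resizing step. It assumes w is laid out as (N, D', H', W', 3), as affine_grid returns it, and vs as (N, C, D, H, W). The helper name resize_warp is hypothetical. Interpolating the grid itself works because the three coordinate channels vary smoothly with position. align_corners=True keeps the -1/+1 endpoints pinned to the corner samples.

```python
import torch
import torch.nn.functional as F


def resize_warp(w: torch.Tensor, size) -> torch.Tensor:
    """Resize a sampling grid w of shape (N, D', H', W', 3) to (N, D, H, W, 3).

    Hypothetical helper: grid_sample's output takes the spatial size of the grid,
    so matching w to vs keeps the warped volume the same size as vs.
    """
    w = w.permute(0, 4, 1, 2, 3)  # (N, 3, D', H', W'): treat the 3 coordinates as channels
    w = F.interpolate(w, size=size, mode="trilinear", align_corners=True)
    return w.permute(0, 2, 3, 4, 1)  # back to (N, D, H, W, 3)


vs = torch.randn(1, 4, 16, 32, 32)  # hypothetical appearance volume
w_small = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 4, 8, 16, 16), align_corners=True)

w = resize_warp(w_small, vs.shape[2:])
warped = F.grid_sample(vs, w, mode="bilinear", padding_mode="border", align_corners=True)
```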

I'm looking at the G3D module; the current implementation is very different from what the paper shows. The paper's diagram is also confusing: why is the depth upscaled and then downscaled during the downsampling phase?

johndpope commented 1 month ago

I dug through the one-shot free-view paper, and they have some flow illustrations: https://ar5iv.labs.arxiv.org/html/2011.15126

I added some visualization to the forward pass in gbase.

One of the key things in that paper is that they limit the number of keypoints to 20. Because we are in ResNet land and have no keypoints, I'm not sure how we can throttle the amount of data in the fields.

Screenshot from 2024-05-27 22-16-47

UPDATE

I added some sampling and made positive and negative values different colours.
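
Here is a minimal sketch of the subsampling and sign-based colouring. It assumes the field has already been turned into a displacement (sampling grid minus base grid) of shape (D, H, W, 3) in (x, y, z) order. The helper name, the stride, and the choice of the depth component for the colour are hypothetical.

```python
import torch


def subsample_field(displacement: torch.Tensor, stride: int = 4):
    """Subsample a displacement field of shape (D, H, W, 3) for a quiver plot.

    Returns voxel indices (k, j, i), the (x, y, z) components per sampled voxel, and a
    boolean mask that is True where the z (depth) component is non-negative.
    Colouring by the sign of z is an assumption; any component could be used.
    """
    sub = displacement[::stride, ::stride, ::stride]  # keep every `stride`-th voxel
    d, h, w, _ = sub.shape
    k, j, i = torch.meshgrid(
        torch.arange(d) * stride, torch.arange(h) * stride, torch.arange(w) * stride,
        indexing="ij",
    )
    comps = sub.reshape(-1, 3)  # one row of (x, y, z) per sampled voxel, same order as the indices
    positive = comps[:, 2] >= 0
    return (k.flatten(), j.flatten(), i.flatten()), comps, positive


field = torch.randn(16, 32, 32, 3) * 0.05  # hypothetical displacement field
(k, j, i), comps, positive = subsample_field(field, stride=4)
# Plot the two groups with separate quiver calls, one colour for `positive` and another for `~positive`.
```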

Screenshot from 2024-05-27 23-19-59

Screenshot from 2024-05-27 23-17-27

When I worked on this paper (https://github.com/johndpope/Emote-hack), I used some code to determine head position, pitch, yaw and roll. I added some more debugging to https://github.com/johndpope/MegaPortrait-hack/blob/main/model.py#L1056.

I added this call to help illustrate things:

self.visualize_warp_fields(xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td)
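
For reference, here is one way to turn a pose (R, t) into the kind of sampling grid passed around here, written as a minimal sketch. It assumes R and t already live in affine_grid's normalized (x, y, z) coordinates. The helpers pose_to_grid and yaw_matrix are hypothetical and are not the repo's code; the repo's mapping from pose to grid may differ.

```python
import math

import torch
import torch.nn.functional as F


def pose_to_grid(R: torch.Tensor, t: torch.Tensor, size) -> torch.Tensor:
    """Build a sampling grid (N, D, H, W, 3) from rotations R (N, 3, 3) and translations t (N, 3).

    Hypothetical helper: assumes R and t are expressed in affine_grid's normalized (x, y, z) coordinates.
    """
    theta = torch.cat([R, t.unsqueeze(-1)], dim=-1)  # (N, 3, 4) affine matrix [R | t]
    return F.affine_grid(theta, size, align_corners=True)


def yaw_matrix(angle: float) -> torch.Tensor:
    """Rotation about the vertical (y) axis, which mixes the x and z coordinates."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


size = (1, 4, 8, 16, 16)  # (N, C, D, H, W) of the volume being warped
identity = pose_to_grid(torch.eye(3).unsqueeze(0), torch.zeros(1, 3), size)
turned = pose_to_grid(yaw_matrix(math.radians(20)).unsqueeze(0), torch.tensor([[0.1, 0.0, 0.0]]), size)
displacement = turned - identity  # vector field to visualize: sampling grid minus base grid
```

A pure yaw plus an x translation leaves the y coordinate untouched, so the y component of the displacement should be zero everywhere.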

I'm happy to chop out code - or accept a PR.

robinchm commented 4 weeks ago

I might be wrong, but my understanding is that the "warp" field generated by affine_grid is the sampling grid, not the actual vector field. To obtain the vector field for visualization, we need to subtract from the warp tensor a "base grid" that samples the space evenly (for example, the output of affine_grid for an identity matrix).

Another trap is that the last dim of the output of affine_grid, of size 3 in our case, does not store its 3 values in the order (D, H, W). It stores them in the order (W, H, D), i.e. (x, y, z). I don't know why it's designed this way, but it took me quite a while to figure out.
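
Here is a quick way to confirm the layout, written as a minimal sketch. Build an identity grid with distinct D, H and W, then check which axis each channel varies along. The result matches grid_sample's (x, y, z) convention: x indexes W, y indexes H and z indexes D.

```python
import torch
import torch.nn.functional as F

d, h, w = 2, 3, 4  # deliberately different sizes so each axis can be identified
base = F.affine_grid(torch.eye(3, 4).unsqueeze(0), [1, 1, d, h, w], align_corners=True)[0]  # (D, H, W, 3)

# Channel 0 changes along W, channel 1 along H, channel 2 along D.
x_along_w = base[0, 0, :, 0]  # W samples from -1 to 1
y_along_h = base[0, :, 0, 1]  # H samples from -1 to 1
z_along_d = base[:, 0, 0, 2]  # D samples from -1 to 1
```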

To visualize the result of affine_grid correctly as a vector field, I use something like:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

import torch
import numpy as np
import torch.nn.functional as F

k = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
                 dtype=torch.float32)
base = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True)

k = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0]],
                 dtype=torch.float32)  # rotate
grid = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True)
grid = grid - base
grid = grid[0]

D, H, W, _ = grid.shape

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

k, j, i = np.meshgrid(
    np.arange(0, D, 1),
    np.arange(0, H, 1),
    np.arange(0, W, 1),
    indexing="ij",
)

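u = grid[..., 0].numpy()  # x component, along W (index i)
v = grid[..., 1].numpy()  # y component, along H (index j)
w = grid[..., 2].numpy()  # z component, along D (index k)

ax.quiver(k, j, i, w, v, u, length=0.3)  # components in the same (D, H, W) order as the positions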
plt.show()

There is an easier way I used to verify the correctness of the warping code: I reduced everything to 2D and then used an image to test it. Generate the affine grid using a known operation, such as a 90-degree rotation plus a 0.2 translation, run it through grid_sample, and output the image. It should look as expected: rotated by 90 degrees and shifted, so that part of the output is black. Since we only use the 'w' tensor in grid_sample, we don't have to figure out its exact layout.
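
Here is a minimal sketch of that 2D test. A small synthetic tensor stands in for the test image, and the helper name rotate90 is hypothetical. Keep in mind that theta maps output coordinates to input sampling coordinates, so the direction of the rotation and the shift follow from that inverse mapping.

```python
import torch
import torch.nn.functional as F


def rotate90(image: torch.Tensor, tx: float) -> torch.Tensor:
    """Rotate an (N, C, H, W) image by 90 degrees and shift along x via affine_grid + grid_sample.

    theta maps *output* coordinates to *input* sampling coordinates (x, y):
    x_in = -y_out + tx, y_in = x_out.
    """
    theta = torch.tensor([[[0.0, -1.0, tx], [1.0, 0.0, 0.0]]])
    grid = F.affine_grid(theta, list(image.shape), align_corners=True)
    return F.grid_sample(image, grid, mode="bilinear", padding_mode="zeros", align_corners=True)


img = torch.arange(64, dtype=torch.float32).view(1, 1, 8, 8)  # stand-in for a real test image
rotated = rotate90(img, 0.0)  # without translation this should match torch.rot90

# With tx = 0.2, output rows whose x_in exceeds 1 sample partly or fully outside the image;
# zero padding darkens them (on an all-ones image the first row drops to about 0.3 at 8 x 8).
ones_shifted = rotate90(torch.ones(1, 1, 8, 8), 0.2)
```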

johndpope commented 4 weeks ago

Nice.

Screenshot from 2024-05-29 17-34-21