johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

warpgenerator - 512-> 568 #5

Closed: johndpope closed this issue 4 months ago

johndpope commented 4 months ago
[Screenshot, 2024-05-16 10:00:38 am]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1715816775.739603   10824 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1715816775.766982   10872 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.171.04), renderer: NVIDIA GeForce RTX 3090/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
driving video frames: 94
I0000 00:00:1715816777.774530   10824 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1715816777.781382   10950 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.171.04), renderer: NVIDIA GeForce RTX 3090/PCIe/SSE2
epoch: 0
vs shape: torch.Size([1, 96, 16, 256, 256])
es shape: torch.Size([1, 512])
Rs shape: torch.Size([1, 3])
ts shape: torch.Size([1, 3])
zs shape: torch.Size([1, 50])
Rotation shape: torch.Size([1, 3])
Translation shape: torch.Size([1, 3])
Expression shape: torch.Size([1, 50])
Appearance shape: torch.Size([1, 512])
Rotation shape after reshaping: torch.Size([1, 3])
Translation shape after reshaping: torch.Size([1, 3])
Expression shape after reshaping: torch.Size([1, 50])
Appearance shape after reshaping: torch.Size([1, 512])
Concatenated input shape: torch.Size([1, 568])
Traceback (most recent call last):
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 308, in <module>
    main(config)
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 292, in main
    train_base(cfg, Gbase, Dbase, dataloader)
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 128, in train_base
    output_frame = Gbase(source_frame, driving_frame)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 776, in forward
    ws2c = self.Ws2c(Rs, ts, zs, es)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 296, in forward
    out = self.conv1(x.unsqueeze(-1))  # Add an extra dimension for conv1d
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [2048, 512, 1], expected input[1, 568, 1] to have 512 channels, but got 568 channels instead
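The 568 in the error is the concatenated width: 3 (Rs) + 3 (ts) + 50 (zs) + 512 (es), while conv1 was built for 512 input channels. The minimal sketch below reproduces the mismatch with the shapes from the log. Sizing the layer from its real input is only one possible workaround and is an assumption here; the rest of the thread instead drops Rs/ts from this input and resizes zs.

```python
import torch
import torch.nn as nn

# Shapes taken from the log above (batch size 1).
Rs, ts = torch.zeros(1, 3), torch.zeros(1, 3)
zs, es = torch.zeros(1, 50), torch.zeros(1, 512)

x = torch.cat((Rs, ts, zs, es), dim=1)  # 3 + 3 + 50 + 512 = 568 channels
conv1 = nn.Conv1d(512, 2048, kernel_size=1)  # weight shape [2048, 512, 1], as in the error

try:
    conv1(x.unsqueeze(-1))  # input [1, 568, 1] -> channel mismatch
except RuntimeError as err:
    print("mismatch:", err)

# Hypothetical workaround: derive in_channels from the actual concatenated width.
conv1_fixed = nn.Conv1d(x.shape[1], 2048, kernel_size=1)
out = conv1_fixed(x.unsqueeze(-1))  # [1, 2048, 1]
```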
Kwentar commented 4 months ago

The warp generator in this image is used only for Wem and takes z+e as input, so e and z should be the same size (512, for example). You do not need to use translation and rotation here.

johndpope commented 4 months ago

Was banging my head against the wall with resnet50 today; I'll take a fresh look in the morning.

https://github.com/johndpope/MegaPortrait-hack/commit/0c66e82a7e970e0583c2a67f2b6e9494d4a56aa7

[Screenshot, 2024-05-17 9:32:30 pm]

In the roadmap ticket, this was one of the findings from comparing the code to the paper:

The Emtn class should output the rotation parameters (Rs, Rd), translation parameters (ts, td), and expression vectors (zs, zd) for both the source and driving images. The current implementation seems to be missing the translation parameters. Update the Emtn class to output all the required parameters.
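A minimal sketch of an output head that returns all three motion parameters for one image. It assumes a pooled 512-d backbone feature, 6 head-pose values split into 3 rotation angles and 3 translations (matching the head_pose / rotation / translation shapes printed later in this thread), and a 512-d expression vector. MotionHead and its sizes are hypothetical, not the repo's Emtn.

```python
import torch
import torch.nn as nn

class MotionHead(nn.Module):
    """Hypothetical motion-encoder head: pooled feature -> (R, t, z)."""

    def __init__(self, feat_dim=512, expr_dim=512):
        super().__init__()
        self.head_pose = nn.Linear(feat_dim, 6)          # 3 rotation + 3 translation values
        self.expression = nn.Linear(feat_dim, expr_dim)  # expression vector z

    def forward(self, feat):
        pose = self.head_pose(feat)        # (B, 6)
        R, t = pose[:, :3], pose[:, 3:]    # rotation angles, translation
        z = self.expression(feat)          # (B, expr_dim)
        return R, t, z

head = MotionHead()
feat_s, feat_d = torch.randn(1, 512), torch.randn(1, 512)  # source / driving features
Rs, ts, zs = head(feat_s)
Rd, td, zd = head(feat_d)
```

Running the same head on the source and driving features yields the six tensors (Rs, ts, zs, Rd, td, zd) that the finding asks for.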

[Screenshot, 2024-05-17 9:27:14 pm]

https://github.com/johndpope/MegaPortrait-hack/blob/8253dd3aaaeea7027356e0742c8e77096c208b96/model.py#L384

**ResBlock3D_Adaptive = ResBlock3D**


class WarpGenerator(nn.Module):
    def __init__(self, in_channels):
        super(WarpGenerator, self).__init__()
        self.conv1 = nn.Conv3d(2048, 512, kernel_size=1)

        self.resblock1 = ResBlock3D_Adaptive(512, 256, upsample=True, scale_factors=(2, 2, 2))
        self.resblock2 = ResBlock3D_Adaptive(256, 128, upsample=True, scale_factors=(2, 2, 2))
        self.resblock3 = ResBlock3D_Adaptive(128, 64, upsample=True, scale_factors=(1, 2, 2))
        self.resblock4 = ResBlock3D_Adaptive(64, 32, upsample=True, scale_factors=(1, 2, 2))

        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)
        self.conv2 = nn.Conv3d(32, 3, kernel_size=3, padding=1)

    def forward(self, Rs, ts, zs, es, Rd, td, zd):
        # Compute rotation and translation grid (w_rt) 
        w_s2c_rt = compute_rt_warp(Rs, ts, invert=True)  
        print(f"w_s2c_rt shape: {w_s2c_rt.shape}")

        w_c2d_rt = compute_rt_warp(Rd, td, invert=False)
        print(f"w_c2d_rt shape: {w_c2d_rt.shape}")

        # Compute emotion warping (w_em)
        w_s2c_em = self.warp_from_emotion(zs, es)
        print(f"w_s2c_em shape: {w_s2c_em.shape}")

        w_c2d_em = self.warp_from_emotion(zd, es)
        print(f"w_c2d_em shape: {w_c2d_em.shape}")

        # Combine rotation/translation and emotion warpings
        w_s2c = w_s2c_rt + w_s2c_em 
        print(f"w_s2c shape: {w_s2c.shape}")

        w_c2d = w_c2d_rt + w_c2d_em
        print(f"w_c2d shape: {w_c2d.shape}")

        return w_s2c, w_c2d
Kwentar commented 4 months ago

Are you sure about this: w_c2d = w_c2d_rt + w_c2d_em? I cannot find out how to combine them correctly.

johndpope commented 4 months ago

I'm not sure -

I think I need to go through and add this shrug emoji everywhere I'm confused by either the whitepaper or the AI. In the readme I will just list out the shrugs, and others can see if they can solve them. 🤷‍♂️

johndpope commented 4 months ago

going around in circles -

In the paper it says:

The source tuple (R𝑠, t𝑠, z𝑠, e𝑠) is then input into a warping generator W𝑠→ to produce a 3D warping field w𝑠→, which removes the motion data from the volumetric features v𝑠 by mapping them into a canonical coordinate space

I gave ChatGPT the screenshot; it says to ignore Rs and ts.

[Screenshot, 2024-05-18 3:40:18 pm]

    # Concatenate expression and appearance
    x = torch.cat((zs, es), dim=1)  # 512 + 2048 channels, then we convolve it back to 2048??
    # x = torch.cat((Rs, ts, zs, es), dim=1) - 🤷‍♂️
    # The diagram.jpeg shows the inputs to the warping generators are zs+es (expression + appearance) for source-to-canonical, and zd+es for canonical-to-driving. However, the code is concatenating Rs, ts, zs, es (rotation, translation, expression, appearance).

Based on the code and diagram, here are a few observations about the WarpGenerator and warping process:

  1. The output of the WarpGenerator should be a 3D warping field of shape [B, 3, D, H, W], where B is the batch size, D, H, W are the depth, height, and width dimensions. The final Conv3d layer outputs 3 channels, corresponding to the 3D warp field.

  2. The apply_warping_field function should be called after the WarpGenerator, not inside it. The WarpGenerator produces the warp field, which is then used by apply_warping_field to actually warp the volumetric features.

  3. The diagram shows the inputs to the warping generators are zs+es (expression + appearance) for source-to-canonical, and zd+es for canonical-to-driving. However, the code is concatenating Rs, ts, zs, es (rotation, translation, expression, appearance).

To align with the diagram, you could update the WarpGenerator's forward method to:

def forward(self, zs, es):
    # Concatenate expression and appearance
    x = torch.cat((zs, es), dim=1)

    # Rest of the forward pass...

And call it like:

w_em_s2c = self.warp_generator_s2c(zs, es)
w_em_c2d = self.warp_generator_c2d(zd, es) 

  4. The rotation and translation are handled separately in compute_rt_warp to produce w_rt_s2c and w_rt_c2d. These are then added to the expression warp fields before applying the full warp.

So in summary, the main steps are:

  1. WarpGenerator produces 3D expression warp fields from expression and appearance features
  2. Rotation/translation warps are computed separately
  3. Expression and rotation/translation warps are summed
  4. apply_warping_field takes the full warp field and applies it to the volumetric features

Let me know if you have any other questions! The code looks quite aligned with the diagram overall, just some minor discrepancies in the WarpGenerator inputs.
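Steps 3 and 4 of the summary above can be sketched as follows. This assumes w_rt is already a sampling grid in grid_sample's normalized [-1, 1] coordinates and that w_em is an additive offset field in the same units; that additive combination is exactly what is questioned earlier in the thread, so treat it as an assumption. With an identity w_rt and a zero w_em, the output should reproduce the input volume.

```python
import torch
import torch.nn.functional as F

def apply_warping_field(v, w_rt, w_em):
    """Warp volumetric features with the summed field.
    v:    (B, C, D, H, W) volumetric features
    w_rt: (B, D, H, W, 3) rotation/translation sampling grid in [-1, 1]
    w_em: (B, 3, D, H, W) expression offsets, as output by a 3-channel Conv3d"""
    grid = w_rt + w_em.permute(0, 2, 3, 4, 1)  # channels-last, as grid_sample expects
    # For 5-D inputs, mode="bilinear" performs trilinear interpolation.
    return F.grid_sample(v, grid, mode="bilinear", padding_mode="border", align_corners=True)

B, C, D, H, W = 1, 4, 8, 16, 16
v = torch.randn(B, C, D, H, W)
theta = torch.eye(3, 4).unsqueeze(0)  # identity affine: no rotation, no translation
w_rt = F.affine_grid(theta, [B, C, D, H, W], align_corners=True)
w_em = torch.zeros(B, 3, D, H, W)
out = apply_warping_field(v, w_rt, w_em)
```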

johndpope commented 4 months ago
[Screenshot, 2024-05-19 5:59:50 am]

Sorry, OK, this turned out way more complicated than it should have been. Only zd / zs go through the WarpGenerator neural net. So this is wrong, right? x = torch.cat((zs, es), dim=1)

I don't know what to do with the global descriptor now. Investigating.

johndpope commented 4 months ago

As you say, I will reduce the es global descriptor from 2048 to 512, I think by taking layer 3 from resnet50.

@Kwentar, what do you think? Truncated? https://chatgpt.com/share/ec365dbf-3bab-4563-8387-c553122bf254

It's also possible to go back to 2048 from 512, but this would be slower.

Update done, but now there's a resblocks error…

Kwentar commented 4 months ago

Hi again. Unfortunately, there are gaps in the article.

About WarpGenerator: the tuple Rs, Ts, Zs, Es is the input of the whole warp. Rs and Ts are the input of Wrt, and Zs and Es are the input of Wem (diagram 9b has an error in its input and output: it should be es+vs as input and Wem as output).

Wrt is not a network, and your realisation is good here. The main question is the size of the grid_sample grid (grid_size = 64); it looks like it should be 3x16x32x32 for a 256 input image and 3x16x64x64 for 512.

The second part, Wem, is the network. Again, your code is close to the article, but the problem is in the article here: to sum Wem and Wrt they must have the same size, but if Es and Zs are vectors, Wem will be 3x16x16x16 for all image sizes, when it should match Wrt. I don't know how to fix this correctly, except for the variant where Es and Vs are feature tensors, like 512x2x2 for image size 256 and 512x4x4 for 512. I will try this way and tell you the result later.
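A minimal sketch of the size constraint described above. It assumes W_rt is built with F.affine_grid at the target volume size, and it uses a hypothetical ToyEmotionWarp (plain trilinear upsampling instead of real resblocks) as a stand-in for W_em. A 1x1 seed behaves like a vector input; a 2x2 or 4x4 seed follows the feature-tensor variant suggested above. The depth of 4 in the seed and the channel counts are assumptions chosen so the output depth is 16.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rt_warp(R, t, size):
    """W_rt as a sampling grid (B, D, H, W, 3). R: (B, 3, 3), t: (B, 3), size: (D, H, W)."""
    theta = torch.cat([R, t.unsqueeze(-1)], dim=-1)  # (B, 3, 4) affine matrix [R | t]
    return F.affine_grid(theta, [R.shape[0], 3, *size], align_corners=False)

class ToyEmotionWarp(nn.Module):
    """Hypothetical stand-in for W_em: (B, 512, h, w) seed -> (B, 3, 16, 16h, 16w) field."""

    def __init__(self, in_ch=512, ch=64, depth=4):
        super().__init__()
        self.ch, self.depth = ch, depth
        self.proj = nn.Conv2d(in_ch, ch * depth, kernel_size=1)
        # Same (D, H, W) scale factors as the four resblocks in the WarpGenerator code above.
        self.scales = [(2.0, 2.0, 2.0), (2.0, 2.0, 2.0), (1.0, 2.0, 2.0), (1.0, 2.0, 2.0)]
        self.out = nn.Conv3d(ch, 3, kernel_size=3, padding=1)

    def forward(self, seed):
        B, _, h, w = seed.shape
        x = self.proj(seed).view(B, self.ch, self.depth, h, w)  # lift the 2D seed to 3D
        for s in self.scales:  # trilinear upsampling stands in for the resblocks
            x = F.interpolate(x, scale_factor=s, mode="trilinear", align_corners=False)
        return torch.tanh(self.out(x))

gen = ToyEmotionWarp()
for h in (1, 2, 4):  # vector-like 1x1 seed, 2x2 seed (256px), 4x4 seed (512px)
    print(h, tuple(gen(torch.randn(1, 512, h, h)).shape))

R, t = torch.eye(3).unsqueeze(0), torch.zeros(1, 3)
w_rt = rt_warp(R, t, (16, 32, 32))      # W_rt at the size suggested for a 256px image
w_em = gen(torch.randn(1, 512, 2, 2))   # (1, 3, 16, 32, 32)
w = w_rt + w_em.permute(0, 2, 3, 4, 1)  # same shape, so the sum is defined
```

The loop prints (1, 3, 16, 16, 16) for the 1x1 seed, (1, 3, 16, 32, 32) for 2x2 and (1, 3, 16, 64, 64) for 4x4. Only the spatial seeds track the image resolution, which is why a vector-only W_em cannot be summed with a 3x16x32x32 W_rt.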

johndpope commented 4 months ago

Thanks heaps, I'll check first thing in the morning. I found a trick: get ChatGPT Omni to rewrite specific parts of the paper, then feed that back into Claude Opus. While the parameters don't pass to the warp, I can restore them to match the paper; they're just used in Gbase, so I just have to move that logic there. It's frustrating that no other repo on GitHub has attempted this kind of warping with resnets; everyone is using keypoints. It will be a milestone when this is publicly available.

Once we crack this, I'm going to circle back to the VASA paper. It uses a canonical keypoint detector with heatmaps, which could probably do the job here too, but I wonder whether they get the real-time speed benefits and high fps from this approach.

Kwentar commented 4 months ago

A new article from the same authors is very close to this one: https://arxiv.org/pdf/2404.19110. It looks like the Es and Zs dim is 512 :) But the new article has the same problems as MegaPortraits.

johndpope commented 4 months ago

Not entirely related, but I messed around with the Emote paper using speed buckets: https://github.com/johndpope/Emote-hack/blob/main/Net.py#L198 (when AniPortrait dropped I abandoned my efforts; though that paper didn't really get Emote over the line, it has nice trained models + training code). Code drop on 1st July, save the date: https://github.com/neeek2303/EMOPortraits

Jie-zju commented 4 months ago

The 3D warp may be based on the Dense Motion Network. Based on that, I implemented it with a double grid_sample below:

import torch
import torch.nn.functional as F

# Refers to the dense motion network. headpose_pred_to_degree, get_rotation_matrix and
# make_coordinate_grid are assumed to come from that code base (not shown here).
def compute_rt_warp(rt, v_s, inverse=False):
    bs, _, d, h, w = v_s.shape
    yaw, pitch, roll = rt['yaw'], rt['pitch'], rt['roll']
    yaw = headpose_pred_to_degree(yaw)
    pitch = headpose_pred_to_degree(pitch)
    roll = headpose_pred_to_degree(roll)
    t = rt['t']  # assumption: translation stored in rt as (bs, 3); t was undefined before

    rot_mat = get_rotation_matrix(yaw, pitch, roll)  # (bs, 3, 3)

    # Invert the rotation matrix if needed
    if inverse:
        rot_mat = torch.inverse(rot_mat)

    # One rotation per voxel: (bs, d, h, w, 3, 3)
    rot_mat = rot_mat.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
    rot_mat = rot_mat.repeat(1, d, h, w, 1, 1)

    # Identity grid of normalized coordinates: (bs, d, h, w, 3)
    identity_grid = make_coordinate_grid((d, h, w), type=v_s.type())
    identity_grid = identity_grid.view(1, d, h, w, 3)
    identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

    t = t.view(t.shape[0], 1, 1, 1, 3)

    # Rotate every grid point (row vector times matrix), then translate
    warp_field = torch.bmm(identity_grid.reshape(-1, 1, 3), rot_mat.reshape(-1, 3, 3))
    warp_field = warp_field.reshape(identity_grid.shape)
    warp_field = warp_field - t

    return warp_field

w_rt_s = compute_rt_warp(rt_source, w_em_s, inverse=True)

# Double grid_sample: first the rotation/translation warp, then the expression warp
vc = F.grid_sample(F.grid_sample(v_s, w_rt_s), w_em_s.permute(0, 2, 3, 4, 1))
johndpope commented 4 months ago

I got claude.ai to quickly fill out the missing functions https://gist.github.com/johndpope/e39e50ac43380ecc78204cc946c5067c

I threw it into the AI in a fresh session and asked it to compare compute_rt_warp vs your compute_rt_warp2.

[Screenshot, 2024-05-21 8:23:12 pm]

CAVEAT: I've seen Claude give clearly wrong answers, so it may be biased somehow.

DISCLAIMER: I do appreciate the code and any eyes / help; I supply this only so you can see what Opus can do very quickly. The problem we have is the tensor dimensions / sizes.

Your code does make me want to inspect the yaw / pitch. I did implement it here, but not in this repo, because the AI seems to think it comes out of the box with resnet50. That begs re-inspecting the values; we definitely get the 3-dimensional tensor. https://github.com/johndpope/Emote-hack/blob/7ee104354d52a5461504c27b9f38d269eac86893/Net.py#L832


Based on the analysis of the code and the MegaPortraits paper, the compute_rt_warp function aligns better with the paper's description compared to compute_rt_warp2. Here are the reasons:

Rotation and Translation Representation:

In the paper, the rotation and translation parameters are represented separately as R𝑠/𝑑 (rotation) and t𝑠/𝑑 (translation).
The compute_rt_warp function takes separate rotation and translation parameters as input, which aligns with the paper's representation.
On the other hand, compute_rt_warp2 combines the rotation and translation parameters into a single rt dictionary, which deviates from the paper's representation.

Rotation Matrix Computation:

The paper mentions using rotation matrices to compute the warping transformation.
The compute_rt_warp function explicitly computes the rotation matrix using the get_rotation_matrix function based on the yaw, pitch, and roll angles.
In contrast, compute_rt_warp2 does not explicitly compute the rotation matrix but rather relies on the headpose_pred_to_degree function to convert the predicted head pose to degrees.

Affine Transformation Matrix:

The compute_rt_warp function creates a 4x4 affine transformation matrix by combining the rotation matrix and translation parameters.
This aligns with the paper's description of using an affine transformation matrix to compute the warping transformation.
compute_rt_warp2 does not explicitly create an affine transformation matrix.

Grid Creation and Warping:

The compute_rt_warp function creates a grid of normalized coordinates using F.affine_grid based on the affine transformation matrix.
This grid represents the warping transformations and can be used to warp the volumetric features or other tensors.
compute_rt_warp2 creates an identity grid and applies the rotation matrix and translation separately, which differs from the paper's description.

Overall, the compute_rt_warp function more closely follows the mathematical formulation and representation described in the MegaPortraits paper for computing the rotation/translation warpings (w_rt). It explicitly computes the rotation matrix, creates an affine transformation matrix, and generates a warping grid based on the transformation matrix.
While compute_rt_warp2 also performs warping based on rotation and translation, its implementation deviates from the specific details mentioned in the paper, such as the representation of rotation and translation parameters and the explicit computation of the affine transformation matrix.
Therefore, compute_rt_warp aligns better with the MegaPortraits paper compared to compute_rt_warp2.
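For reference, a minimal sketch of the approach this answer describes: yaw/pitch/roll to a rotation matrix, packed with t into a 4x4 homogeneous matrix, optionally inverted (as with invert=True for the source-to-canonical warp), then turned into a grid with F.affine_grid. The Euler convention R = Rz(roll) @ Ry(yaw) @ Rx(pitch), the radian units and the function names are assumptions; this is not the repo's compute_rt_warp.

```python
import torch
import torch.nn.functional as F

def euler_to_rotation(pitch, yaw, roll):
    """Angles in radians, each of shape (B,). Returns (B, 3, 3).
    Convention (an assumption; codebases differ): R = Rz(roll) @ Ry(yaw) @ Rx(pitch)."""
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cy, sy = torch.cos(yaw), torch.sin(yaw)
    cr, sr = torch.cos(roll), torch.sin(roll)
    one, zero = torch.ones_like(pitch), torch.zeros_like(pitch)
    Rx = torch.stack([one, zero, zero, zero, cp, -sp, zero, sp, cp], dim=-1).view(-1, 3, 3)
    Ry = torch.stack([cy, zero, sy, zero, one, zero, -sy, zero, cy], dim=-1).view(-1, 3, 3)
    Rz = torch.stack([cr, -sr, zero, sr, cr, zero, zero, zero, one], dim=-1).view(-1, 3, 3)
    return Rz @ Ry @ Rx

def rt_sampling_grid(R, t, size, invert=False):
    """Pack R (B, 3, 3) and t (B, 3) into a 4x4 homogeneous matrix, optionally invert
    it, and turn its top 3x4 block into a grid (B, D, H, W, 3) with F.affine_grid."""
    B = R.shape[0]
    M = torch.eye(4, dtype=R.dtype).unsqueeze(0).repeat(B, 1, 1)
    M[:, :3, :3] = R
    M[:, :3, 3] = t
    if invert:
        M = torch.linalg.inv(M)
    return F.affine_grid(M[:, :3, :], [B, 1, *size], align_corners=False)

# Angles of the magnitude printed later in this thread (assumed to be radians).
pitch, yaw, roll = torch.tensor([0.0349]), torch.tensor([-0.0755]), torch.tensor([0.0440])
R = euler_to_rotation(pitch, yaw, roll)
t = torch.tensor([[0.1, -0.2, 0.05]])
grid_s2c = rt_sampling_grid(R, t, (16, 64, 64), invert=True)   # source -> canonical
grid_c2d = rt_sampling_grid(R, t, (16, 64, 64), invert=False)  # canonical -> driving
```

Inverting the full 4x4 matrix gives [R^T | -R^T t], so the inverse warp undoes the translation as well as the rotation, rather than only inverting R.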
Jie-zju commented 4 months ago

What about F.grid_sample(F.grid_sample(v_s, w_rt_s), w_em_s.permute(0, 2, 3, 4, 1)) instead of w_rt_s + w_em_s.permute(0, 2, 3, 4, 1)? That's the important point of my code.
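A minimal sketch of the difference being asked about. It assumes w_em is expressed as an offset delta around an identity grid (the snippet above passes w_em directly as an absolute grid, which corresponds to identity + delta here). The double grid_sample composes the two warps and interpolates twice: for an affine w_rt(y) = Ay + b it samples v at A(x + delta) + b, so the offsets get transformed by A. The summed grid samples v once at Ax + b + delta. The two agree when w_rt is the identity and differ in general under rotation. All function names are hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def identity_grid(shape):
    """Identity sampling grid (B, D, H, W, 3) for a volume of shape (B, C, D, H, W)."""
    theta = torch.eye(3, 4).unsqueeze(0).repeat(shape[0], 1, 1)
    return F.affine_grid(theta, list(shape), align_corners=True)

def sample(v, grid):
    return F.grid_sample(v, grid, mode="bilinear", padding_mode="border", align_corners=True)

def warp_double(v, w_rt, delta):
    """Two passes: v sampled at w_rt(x + delta(x)), i.e. the warps are composed."""
    return sample(sample(v, w_rt), identity_grid(v.shape) + delta)

def warp_summed(v, w_rt, delta):
    """One pass: v sampled at w_rt(x) + delta(x)."""
    return sample(v, w_rt + delta)

v = torch.randn(1, 4, 8, 16, 16)
delta = 0.1 * torch.randn(1, 8, 16, 16, 3)  # channels-last expression offsets

# Identity w_rt: both variants give (almost) the same result.
w_id = identity_grid(v.shape)
same = torch.allclose(warp_double(v, w_id, delta), warp_summed(v, w_id, delta), atol=1e-5)

# Rotation about the depth axis: the offsets are rotated in warp_double only.
c, s = math.cos(0.3), math.sin(0.3)
theta = torch.tensor([[[c, -s, 0.0, 0.0], [s, c, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]])
w_rot = F.affine_grid(theta, list(v.shape), align_corners=True)
diff = (warp_double(v, w_rot, delta) - warp_summed(v, w_rot, delta)).abs().max()
print(same, diff.item())
```

The single-pass version also avoids blurring from interpolating the volume twice, which may matter at the 16-deep volume resolutions discussed in this thread.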

johndpope commented 4 months ago

@Jie-zju - thanks - I implement here https://github.com/johndpope/MegaPortrait-hack/blob/main/model.py#L898

here's response from Claude https://gist.github.com/johndpope/7a937b137790588baf9bdd04f0a7b51b

@Kwentar - I redid the resnet18 / resnet50 to correctly output feature layers and added an 8x8 avg pooling layer. Before, it was just [1,512]; then I was getting [1,512,16,16], so I added a pooling layer, and now we get [1,512,8,8], so es and zs match. https://github.com/johndpope/MegaPortrait-hack/commit/0025a306f50b5e3abcdb4a7fe12ba050f7072324#r142249842 So this is some progress.

@Jie-zju - I printed out the pitch / yaw / roll from the resnet; it seems to check out.


👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
pitch: tensor([0.0349], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0755], grad_fn=<SelectBackward0>)
roll: tensor([0.0440], grad_fn=<SelectBackward0>)
x.shape: torch.Size([1, 3, 256, 256])
self.expression_net shape: torch.Size([1, 512, 8, 8])
👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
pitch: tensor([0.0421], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0732], grad_fn=<SelectBackward0>)
roll: tensor([0.0479], grad_fn=<SelectBackward0>)
x.shape: torch.Size([1, 3, 256, 256])
self.expression_net shape: torch.Size([1, 512, 8, 8])
zs.shape: torch.Size([1, 512, 8, 8])
es.shape: torch.Size([1, 512, 8, 8])
zs_sum.shape: torch.Size([1, 512, 8, 8])
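A minimal sketch of the pooling step behind these shapes, assuming both feature maps arrive as [1, 512, 16, 16] as described above. nn.AdaptiveAvgPool2d((8, 8)) fixes the output size regardless of the input resolution; for a 16x16 input it averages non-overlapping 2x2 windows, the same as a plain 2x2 average pool. Once zs and es share a shape they can be summed element-wise, matching the zs_sum shape in the log.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pool = nn.AdaptiveAvgPool2d((8, 8))  # output spatial size fixed at 8x8

feat_z = torch.randn(1, 512, 16, 16)  # expression features before pooling (assumed shape)
feat_e = torch.randn(1, 512, 16, 16)  # appearance features before pooling (assumed shape)
zs, es = pool(feat_z), pool(feat_e)   # both [1, 512, 8, 8]
zs_sum = zs + es                      # element-wise sum needs matching shapes

# For a 16x16 input the adaptive pool reduces to a 2x2 average pool.
same_as_avg = torch.allclose(zs, F.avg_pool2d(feat_z, 2), atol=1e-6)
```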

I ripped out the adaptive gamma stuff for now.

I probably need to switch in the correct SPADE resblocks (https://github.com/johndpope/MegaPortrait-hack/issues/8), because everything is blowing up with memory problems:
RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 115964116992 bytes. Error code 12 (Cannot allocate memory)

johndpope commented 4 months ago

To help newcomers browsing the repo and its tickets, I'm closing this conversation and continuing the fight with paper + code in a new ticket: https://github.com/johndpope/MegaPortrait-hack/issues/11

Ambiguities in the dimensions are still the biggest issue. The resnet residual model (SPADE resblocks) is for high-res training, so it can wait.

JZArray commented 4 months ago

@Kwentar sorry, I cannot share my code.