johndpope closed this 4 months ago
The warp generator in this image is used only for Wem and takes z+e as input, so e and z should be the same size (512, for example). You do not need to use translation and rotation here.
Was banging my head against the wall with resnet50 today - I'll take a fresh look in the morning.
https://github.com/johndpope/MegaPortrait-hack/commit/0c66e82a7e970e0583c2a67f2b6e9494d4a56aa7
in the roadmap ticket - this was one of the findings comparing code to paper.
The Emtn class should output the rotation parameters (Rs, Rd), translation parameters (ts, td), and expression vectors (zs, zd) for both the source and driving images. The current implementation seems to be missing the translation parameters. Update the Emtn class to output all the required parameters.
**ResBlock3D_Adaptive = ResBlock3D***
```python
class WarpGenerator(nn.Module):
    def __init__(self, in_channels):
        super(WarpGenerator, self).__init__()
        self.conv1 = nn.Conv3d(2048, 512, kernel_size=1)
        self.resblock1 = ResBlock3D_Adaptive(512, 256, upsample=True, scale_factors=(2, 2, 2))
        self.resblock2 = ResBlock3D_Adaptive(256, 128, upsample=True, scale_factors=(2, 2, 2))
        self.resblock3 = ResBlock3D_Adaptive(128, 64, upsample=True, scale_factors=(1, 2, 2))
        self.resblock4 = ResBlock3D_Adaptive(64, 32, upsample=True, scale_factors=(1, 2, 2))
        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)
        self.conv2 = nn.Conv3d(32, 3, kernel_size=3, padding=1)

    def forward(self, Rs, ts, zs, es, Rd, td, zd):
        # Compute rotation and translation grid (w_rt)
        w_s2c_rt = compute_rt_warp(Rs, ts, invert=True)
        print(f"w_s2c_rt shape: {w_s2c_rt.shape}")
        w_c2d_rt = compute_rt_warp(Rd, td, invert=False)
        print(f"w_c2d_rt shape: {w_c2d_rt.shape}")

        # Compute emotion warping (w_em)
        w_s2c_em = self.warp_from_emotion(zs, es)
        print(f"w_s2c_em shape: {w_s2c_em.shape}")
        w_c2d_em = self.warp_from_emotion(zd, es)
        print(f"w_c2d_em shape: {w_c2d_em.shape}")

        # Combine rotation/translation and emotion warpings
        w_s2c = w_s2c_rt + w_s2c_em
        print(f"w_s2c shape: {w_s2c.shape}")
        w_c2d = w_c2d_rt + w_c2d_em
        print(f"w_c2d shape: {w_c2d.shape}")

        return w_s2c, w_c2d
```
Are you sure about this: w_c2d = w_c2d_rt + w_c2d_em? I cannot find the answer on how to combine them correctly.
I'm not sure -
I think I need to go through and add this shrug emoji everywhere I'm confused by either the whitepaper or the AI. The readme will just list out the shrugs, and others can see if they can solve them. 🤷‍♂️
going around in circles -
in the paper it says
# which removes the motion data from the volumetric features v𝑠 by mapping them into a canonical coordinate space
I give chatgpt the screenshot - it says ignore the Rs, ts
x = torch.cat((zs, es), dim=1)  # 512 + 2048 channels - then we convolve it back to 2048??
# x = torch.cat((Rs, ts, zs, es), dim=1) - 🤷♂️
# The diagram.jpeg shows the inputs to the warping generators are zs+es (expression + appearance) for source-to-canonical, and zd+es for canonical-to-driving. However, the code is concatenating Rs, ts, zs, es (rotation, translation, expression, appearance).
Based on the code and diagram, here are a few observations about the WarpGenerator and warping process:
The output of the WarpGenerator should be a 3D warping field of shape [B, 3, D, H, W], where B is the batch size, D, H, W are the depth, height, and width dimensions. The final Conv3d layer outputs 3 channels, corresponding to the 3D warp field.
The apply_warping_field function should be called after the WarpGenerator, not inside it. The WarpGenerator produces the warp field, which is then used by apply_warping_field to actually warp the volumetric features.
The diagram shows the inputs to the warping generators are zs+es (expression + appearance) for source-to-canonical, and zd+es for canonical-to-driving. However, the code is concatenating Rs, ts, zs, es (rotation, translation, expression, appearance).
To align with the diagram, you could update the WarpGenerator's forward method to:
```python
def forward(self, zs, es):
    # Concatenate expression and appearance
    x = torch.cat((zs, es), dim=1)
    # Rest of the forward pass...
```
And call it like:
```python
w_em_s2c = self.warp_generator_s2c(zs, es)
w_em_c2d = self.warp_generator_c2d(zd, es)
```
So in summary, the main steps are:
Let me know if you have any other questions! The code looks quite aligned with the diagram overall, just some minor discrepancies in the WarpGenerator inputs.
sorry - ok - this turned out way more complicated than it should have.
only zd / zs goes through warpgenerator neural net.
this is wrong, right?
x = torch.cat((zs, es), dim=1)
i don't know what to do with global descriptor now. investigating.
As you say - I will reduce es global descriptor to 512 from 2048 - I think layer 3 from resnet50
@Kwentar - what do you think ? Truncated? https://chatgpt.com/share/ec365dbf-3bab-4563-8387-c553122bf254
it’s also possible to go back to 2048 from the 512…but this would be slower.
Update: done, but now there's a resblocks error…
Hi again, Unfortunately, There are gaps in the article.
About WarpGenerator:
tuple Rs, Ts, Zs, Es -- is the input of the whole warp; Rs and Ts are the input of Wrt, and Zs and Es are the input of Wem (diagram 9b has an error in input and output: it should be es+vs as input and Wem as output).
Wrt -- not a network. Your implementation is good here; the main question is the size of the grid sample (grid_size = 64). It looks like it should be 3x16x32x32 for a 256 input image and 3x16x64x64 for 512.
The second part, Wem, is the network. Again your code is close to the article, but the problem here is in the article: to sum Wem and Wrt they should have the same size, but if Es and Zs are vectors we will have a Wem of size 3x16x16x16 for all image sizes, when it should be the same size as Wrt. I don't know how to fix it correctly, except the variant where Es and Vs are feature tensors like 512x2x2 for image size 256 and 512x4x4 for 512. I will try this way and tell you the result later.
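To make the size mismatch concrete, here is a minimal sketch that only tracks spatial sizes through a chain of upsampling blocks. `upsampled_size` is a hypothetical helper, not something from the repo. `ALL_2X` assumes four blocks that each double every axis (which reproduces the 16x16x16 mentioned above), `THREAD_SCALES` copies the scale factors from the WarpGenerator posted earlier in this thread, and the starting depth of 4 in the last two calls is an assumption chosen so that depth reaches 16.

```python
def upsampled_size(start, scale_factors):
    """Return (D, H, W) after applying each (sd, sh, sw) upsampling step."""
    d, h, w = start
    for sd, sh, sw in scale_factors:
        d, h, w = d * sd, h * sh, w * sw
    return (d, h, w)

# Assumption: four blocks that double every axis.
ALL_2X = [(2, 2, 2)] * 4
# Scale factors of resblock1..resblock4 in the WarpGenerator posted above.
THREAD_SCALES = [(2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2)]

# A plain vector is a 1x1x1 volume, so the output size is fixed
# no matter how large the input image was.
print(upsampled_size((1, 1, 1), ALL_2X))         # (16, 16, 16)
print(upsampled_size((1, 1, 1), THREAD_SCALES))  # (4, 16, 16)

# Starting from a small feature volume instead lets the size follow the image.
print(upsampled_size((4, 2, 2), THREAD_SCALES))  # (16, 32, 32)
print(upsampled_size((4, 4, 4), THREAD_SCALES))  # (16, 64, 64)
```

The last two results match the 16x32x32 and 16x64x64 grids suggested for Wrt, which is why feeding feature tensors rather than vectors looks like a plausible fix.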
Thanks heaps - I'll check first thing in the morning. I found a trick - get ChatGPT Omni to rewrite specific parts of the paper, then feed that back into Claude Opus. While the parameters don't pass to warp, I can restore them to match the paper - they're just used in the gbase, I just have to move that logic there. It's frustrating that there's no other repo on GitHub that's attempted this kind of warping with resnets and everyone is using key points - will be a milestone when this is publicly available.
Once we crack this - gonna circle back to the VASA paper. It's using a canonical key point detector with heat map - it could probably do the job here too - but I wonder if they get the real time speed benefits and high fps due to this approach.
New article from the authors, very close to this one: https://arxiv.org/pdf/2404.19110 Looks like Es and Zs dim is 512 :) But the new article has the same problems as MegaPortrait
not entirely related but I messed around with Emote paper with speed buckets https://github.com/johndpope/Emote-hack/blob/main/Net.py#L198 - (when aniportrait dropped I abandoned efforts. though that paper didnt really get that emote over the line. has nice trained models + training code ) - 1st July code drop - save the date https://github.com/neeek2303/EMOPortraits
The 3D warp may be based on the Dense Motion Network. Based on this, I implement it BY DOUBLE grid_sample below:
```python
import torch
import torch.nn.functional as F

# Refers to the dense motion network. headpose_pred_to_degree,
# get_rotation_matrix and make_coordinate_grid are helpers not shown here.
def compute_rt_warp(rt, v_s, inverse=False):
    bs, _, d, h, w = v_s.shape
    yaw, pitch, roll = rt['yaw'], rt['pitch'], rt['roll']
    # Assumption: the translation lives in the same dict under 't';
    # the original snippet used t without defining it.
    t = rt['t']
    yaw = headpose_pred_to_degree(yaw)
    pitch = headpose_pred_to_degree(pitch)
    roll = headpose_pred_to_degree(roll)
    rot_mat = get_rotation_matrix(yaw, pitch, roll)  # (bs, 3, 3)

    # Invert the transformation matrix if needed
    if inverse:
        rot_mat = torch.inverse(rot_mat)

    # Broadcast one rotation matrix to every voxel: (bs, d, h, w, 3, 3)
    rot_mat = rot_mat.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
    rot_mat = rot_mat.repeat(1, d, h, w, 1, 1)

    identity_grid = make_coordinate_grid((d, h, w), type=v_s.type())
    identity_grid = identity_grid.view(1, d, h, w, 3)
    identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

    t = t.view(t.shape[0], 1, 1, 1, 3)

    # rotate, then translate
    warp_field = torch.bmm(identity_grid.reshape(-1, 1, 3), rot_mat.reshape(-1, 3, 3))
    warp_field = warp_field.reshape(identity_grid.shape)
    warp_field = warp_field - t
    return warp_field

w_rt_s = compute_rt_warp(rt_source, w_em_s, inverse=True)
vc = F.grid_sample(F.grid_sample(v_s, w_rt_s), w_em_s.permute(0, 2, 3, 4, 1))
```
I got claude.ai to quickly fill out the missing functions https://gist.github.com/johndpope/e39e50ac43380ecc78204cc946c5067c
I threw it into the AI in a fresh session and asked it to compare compute_rt_warp vs your compute_rt_warp2
CAVEAT - I've seen Claude provide clearly wrong answers - so it may be biased somehow....
DISCLAIMER - I do appreciate code - and any eyes / help - I supply this only for benefit of seeing what Opus can do very quickly. the problem we have is the tensor dimensions - sizes -
your code does want me to inspect the yaw / pitch - i did implement here - but not in this repo because ai seems to think it's out of the box with resnet50. begs reinspecting the values. we definitely get the 3 dimension tensor. https://github.com/johndpope/Emote-hack/blob/7ee104354d52a5461504c27b9f38d269eac86893/Net.py#L832
Based on the analysis of the code and the MegaPortraits paper, the compute_rt_warp function aligns better with the paper's description compared to compute_rt_warp2. Here are the reasons:
Rotation and Translation Representation:
In the paper, the rotation and translation parameters are represented separately as R𝑠/𝑑 (rotation) and t𝑠/𝑑 (translation).
The compute_rt_warp function takes separate rotation and translation parameters as input, which aligns with the paper's representation.
On the other hand, compute_rt_warp2 combines the rotation and translation parameters into a single rt dictionary, which deviates from the paper's representation.
Rotation Matrix Computation:
The paper mentions using rotation matrices to compute the warping transformation.
The compute_rt_warp function explicitly computes the rotation matrix using the get_rotation_matrix function based on the yaw, pitch, and roll angles.
In contrast, compute_rt_warp2 does not explicitly compute the rotation matrix but rather relies on the headpose_pred_to_degree function to convert the predicted head pose to degrees.
Affine Transformation Matrix:
The compute_rt_warp function creates a 4x4 affine transformation matrix by combining the rotation matrix and translation parameters.
This aligns with the paper's description of using an affine transformation matrix to compute the warping transformation.
compute_rt_warp2 does not explicitly create an affine transformation matrix.
Grid Creation and Warping:
The compute_rt_warp function creates a grid of normalized coordinates using F.affine_grid based on the affine transformation matrix.
This grid represents the warping transformations and can be used to warp the volumetric features or other tensors.
compute_rt_warp2 creates an identity grid and applies the rotation matrix and translation separately, which differs from the paper's description.
Overall, the compute_rt_warp function more closely follows the mathematical formulation and representation described in the MegaPortraits paper for computing the rotation/translation warpings (w_rt). It explicitly computes the rotation matrix, creates an affine transformation matrix, and generates a warping grid based on the transformation matrix.
While compute_rt_warp2 also performs warping based on rotation and translation, its implementation deviates from the specific details mentioned in the paper, such as the representation of rotation and translation parameters and the explicit computation of the affine transformation matrix.
Therefore, compute_rt_warp aligns better with the MegaPortraits paper compared to compute_rt_warp2.
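For readers who want to see what the affine-matrix route described above could look like, here is a minimal sketch. It is not the repo's compute_rt_warp: `rt_to_grid` is a hypothetical name, and the axis order and sign conventions of `R` and `t` are assumptions. It relies only on `torch.inverse`, `F.affine_grid` and `F.grid_sample`, with `align_corners=False` used consistently.

```python
import torch
import torch.nn.functional as F

def rt_to_grid(R, t, size, invert=False):
    """Build a sampling grid from a rotation R (B, 3, 3) and translation t (B, 3).

    size is the (D, H, W) of the volume to warp; the result is (B, D, H, W, 3).
    """
    B = R.shape[0]
    # Assemble a 4x4 homogeneous transform so it can be inverted as a whole.
    affine = torch.eye(4, dtype=R.dtype, device=R.device).repeat(B, 1, 1)
    affine[:, :3, :3] = R
    affine[:, :3, 3] = t
    if invert:
        affine = torch.inverse(affine)
    theta = affine[:, :3, :]  # affine_grid expects (B, 3, 4) for 5D inputs
    D, H, W = size
    return F.affine_grid(theta, (B, 1, D, H, W), align_corners=False)

# Sanity check: identity rotation and zero translation leave the volume unchanged.
R = torch.eye(3).unsqueeze(0)
t = torch.zeros(1, 3)
grid = rt_to_grid(R, t, (16, 32, 32))
v = torch.randn(1, 8, 16, 32, 32)
out = F.grid_sample(v, grid, align_corners=False)
print(grid.shape)
print(torch.allclose(out, v, atol=1e-4))
```

Inverting the full 4x4 matrix inverts rotation and translation together, which is one difference from only inverting the rotation matrix.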
What about F.grid_sample(F.grid_sample(v_s, w_rt_s), w_em_s.permute(0, 2, 3, 4, 1)) instead of w_rt_s + w_em_s.permute(0, 2, 3, 4, 1) ? That's important point for my code?
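One way to see why this matters is a minimal sketch with two identity warps. It assumes both warps are absolute sampling grids in the [-1, 1] convention of `F.grid_sample`; whether the repo's w_em is an absolute grid or an offset is not settled in this thread, so treat the names below as illustrative only.

```python
import torch
import torch.nn.functional as F

B, C, D, H, W = 1, 4, 8, 16, 16
v = torch.randn(B, C, D, H, W)

# Identity sampling grid of shape (B, D, H, W, 3).
theta = torch.eye(3, 4).unsqueeze(0)
identity = F.affine_grid(theta, (B, C, D, H, W), align_corners=False)

# Two warps that should both do nothing, stored as absolute grids.
w_rt = identity.clone()
w_em = identity.clone()

# Composition: apply one warp, then the other.
composed = F.grid_sample(
    F.grid_sample(v, w_rt, align_corners=False), w_em, align_corners=False
)

# Plain sum of absolute grids: the coordinates are doubled.
summed = F.grid_sample(v, w_rt + w_em, align_corners=False)

# Sum where the second warp is treated as an offset from identity.
offset_sum = F.grid_sample(v, w_rt + (w_em - identity), align_corners=False)

print(torch.allclose(composed, v, atol=1e-4))
print(torch.allclose(summed, v, atol=1e-4))
print(torch.allclose(offset_sum, v, atol=1e-4))
```

Under these assumptions, composing and offset-style summing both leave the volume unchanged, while summing two absolute grids does not. For non-trivial warps, composition and offset summation are generally not equivalent either, so the choice is a real design decision.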
@Jie-zju - thanks - I implement here https://github.com/johndpope/MegaPortrait-hack/blob/main/model.py#L898
here's response from Claude https://gist.github.com/johndpope/7a937b137790588baf9bdd04f0a7b51b
@Kwentar - I redid the resnet18 / resnet50 to correctly spit out feature layers, and added an 8x8 avg pooling layer. Before it was just [1,512]; after the change it was giving [1,512,16,16], so I added a pooling layer and now we get [1,512,8,8], so es and zs match. https://github.com/johndpope/MegaPortrait-hack/commit/0025a306f50b5e3abcdb4a7fe12ba050f7072324#r142249842 so this is some progress.
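As a rough illustration of that pooling step (not the code from the commit), `nn.AdaptiveAvgPool2d` reduces a [1,512,16,16] feature map to [1,512,8,8]. The random tensor below is a stand-in for the real ResNet feature map.

```python
import torch
import torch.nn as nn

# Stand-in for the feature map coming out of the truncated ResNet.
features = torch.randn(1, 512, 16, 16)

# Adaptive pooling takes the desired output size, so it also works
# if the incoming feature map changes resolution.
pool = nn.AdaptiveAvgPool2d((8, 8))
pooled = pool(features)
print(pooled.shape)
```

For a 16x16 input this is the same as a 2x2 average pool with stride 2.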
@Jie-zju - I print out the pitch / yaw / roll from resnet - seems to check out.
👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
**pitch: tensor([0.0349], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0755], grad_fn=<SelectBackward0>)
roll: tensor([0.0440], grad_fn=<SelectBackward0>)**
x.shape: torch.Size([1, 3, 256, 256])
self.expression_net shape: torch.Size([1, 512, 8, 8])
👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
**pitch: tensor([0.0421], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0732], grad_fn=<SelectBackward0>)
roll: tensor([0.0479], grad_fn=<SelectBackward0>)**
x.shape: torch.Size([1, 3, 256, 256])
self.expression_net shape: torch.Size([1, 512, 8, 8])
zs.shape: torch.Size([1, 512, 8, 8])
es.shape: torch.Size([1, 512, 8, 8])
zs_sum.shape: torch.Size([1, 512, 8, 8])
I ripped out the adaptivegamma stuff for now.
need to probably switch in the correct spade resblocks https://github.com/johndpope/MegaPortrait-hack/issues/8 cause everything is blowing up with memory problems. RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 115964116992 bytes. Error code 12 (Cannot allocate memory)
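To put the number in that error into perspective, here is a small sketch that converts the requested byte count to GiB and estimates the size of a dense tensor from its shape. `tensor_bytes` is a hypothetical helper, float32 (4 bytes per element) is an assumption, and the example shape is made up for illustration; it is not the shape that triggered the error.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Bytes needed for a dense tensor of the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element

requested = 115964116992          # from the RuntimeError above
print(requested / 2**30)          # size in GiB
print(requested // 4)             # element count if the dtype is float32

# Illustrative shape only: a 512-channel 16x64x64 volume.
print(tensor_bytes((1, 512, 16, 64, 64)) / 2**20)  # size in MiB
```

The request works out to 108 GiB in a single allocation, so printing the shapes right before the failing call is probably the quickest way to find which tensor is exploding.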
to help newbies hitting repo and looking through tickets - i close this conversation and continue the fight with paper + code on new ticket. https://github.com/johndpope/MegaPortrait-hack/issues/11
Ambiguities on dimensions are still the biggest issue. The resnet residual model (spade resblocks) is for high res training, so it can wait.
@Kwentar sry, I cannot share my codes