I feel the current warping code is almost correct in using affine_grid. You can supply an identity affine matrix to affine_grid to obtain a starting meshgrid instead of building one with linspace. Something like:
import torch
import torch.nn.functional as F

# 2D example with batch_size 1: identity theta reproduces the base meshgrid
theta = torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=torch.float32)
grid = F.affine_grid(theta, (1, 16, 64, 64), align_corners=True)  # (1, 64, 64, 2)
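As a quick check (my own snippet, continuing the one above, not code from the repo), that identity grid is exactly the meshgrid you would build by hand with linspace when align_corners=True:

# hand-built equivalent: x varies along W, y along H, both linspace(-1, 1)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 64),
                        torch.linspace(-1, 1, 64), indexing="ij")
manual = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # last dim is (x, y)
assert torch.allclose(grid, manual)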
However I think we don't need such a "base grid".
The unnatural part of the current implementation is the normalization of w. affine_grid guarantees a grid in [-1, 1], and while I can see that w_em and the sum may reach outside [-1, 1], I feel we should not manually perform any normalization: the network should learn that w_em needs to be small. That is, we simply add the two warpings, w = w_em + w_rt, and pass the result to grid_sample. The output of affine_grid already includes the "base grid", so there is no need to add another one.
In short, the apply_warping_field method can be replaced by v_canonical = F.grid_sample(v, w, mode='bilinear', padding_mode='border', align_corners=True). (Note that grid_sample has no 'trilinear' mode; for a 5D input, 'bilinear' already performs trilinear interpolation.)
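Putting it together, a minimal sketch of that apply step, assuming v is the (N, C, D, H, W) feature volume, w_rt comes from affine_grid, and w_em is the predicted expression warp in the same (N, D, H, W, 3) layout (the shapes and placeholder values are my assumptions, not the repo's actual code):

import torch
import torch.nn.functional as F

N, C, D, H, W = 1, 96, 16, 64, 64
v = torch.randn(N, C, D, H, W)               # stand-in volumetric features

# w_rt: pose warp from affine_grid; it already contains the base grid
theta = torch.eye(3, 4).unsqueeze(0)         # placeholder rotation/translation
w_rt = F.affine_grid(theta, (N, C, D, H, W), align_corners=True)  # (N, D, H, W, 3)

w_em = torch.zeros_like(w_rt)                # stand-in for the predicted offsets

w = w_em + w_rt                              # just add; no extra normalization
v_canonical = F.grid_sample(v, w, mode='bilinear',
                            padding_mode='border', align_corners=True)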
sorry pushed some broken code - fixing now.
I think you need to keep interpolate in the apply method. According to the diagram in the paper, warping should not change the size of vs, so w should be scaled to the same spatial size as vs, as in the sketch below.
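A minimal sketch of that scaling step, assuming w is laid out as (N, D, H, W, 3) like the output of affine_grid (the helper name is mine):

import torch
import torch.nn.functional as F

def resize_warp_field(w: torch.Tensor, size) -> torch.Tensor:
    # F.interpolate expects channels first, so move the coordinate dim
    w = w.permute(0, 4, 1, 2, 3)                                  # (N, 3, D, H, W)
    w = F.interpolate(w, size=size, mode='trilinear', align_corners=True)
    return w.permute(0, 2, 3, 4, 1)                               # (N, D', H', W', 3)

# e.g. match vs before sampling: w = resize_warp_field(w, vs.shape[2:])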
I'm looking at the G3D module, whose current implementation is quite different from what the paper shows. The paper's diagram is also confusing: why is the depth upscaled and then downscaled during the downsampling phase?
I dug through the one-shot free-view paper, and they have some flow illustrations: https://ar5iv.labs.arxiv.org/html/2011.15126
I added this to the forward pass in Gbase.
One of the key things in that paper is that they limit the keypoints to 20. Because we are in ResNet land with no keypoints, I'm not sure how we can throttle the amount of data in the fields.
UPDATE: I added some sampling and made positive/negative values different colours.
When I worked on this paper, https://github.com/johndpope/Emote-hack, I used some code to determine head position (pitch / yaw / roll). I added some more debug into https://github.com/johndpope/MegaPortrait-hack/blob/main/model.py#L1056
I added this to help illustrate things.
I'm happy to chop out code - or accept a PR.
I might be wrong, but what I understand is that the "warp" field generated by affine_grid is the sampling grid, not the actual vector field. To obtain the vector field for visualization, we need to subtract from the warp tensor a "base grid" that evenly samples the space.
Another trap is that the last dim of the output of affine_grid (in our case 3) stores its values not in the order (D, H, W) but in the order (x, y, z), i.e. (W, H, D). I don't know why it's designed this way, but it took me quite a while to figure out.
To correctly visualize the result of affine_grid as a vector field, I use something like:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
import torch
import numpy as np
import torch.nn.functional as F

# identity affine -> base grid that evenly samples [-1, 1]^3
k = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
                 dtype=torch.float32)
base = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True)

# a 90-degree rotation mixing the y and z axes
k = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0]],
                 dtype=torch.float32)  # rotate
grid = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True)

# subtract the base grid: sampling grid -> displacement (vector) field
grid = grid - base
grid = grid[0]
D, H, W, _ = grid.shape

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
k, j, i = np.meshgrid(
    np.arange(0, D, 1),
    np.arange(0, H, 1),
    np.arange(0, W, 1),
    indexing="ij",
)
# the grid's last dim is (x, y, z), so reverse it to match the (D, H, W) axes
u = grid[..., 0].numpy()
v = grid[..., 1].numpy()
w = grid[..., 2].numpy()
ax.quiver(k, j, i, w, v, u, length=0.3)
plt.show()
There is an easier way I used to verify the correctness of the warping code: I reduced everything to 2D and tested with an image. Generate the affine grid from a known operation like a 90-degree rotation plus a 0.2 translation, run it through grid_sample, then output the image; it should look as expected (rotated 90 degrees, shifted so part of the output is black). Since we only pass the 'w' tensor to grid_sample, we don't have to figure out its exact layout.
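A rough version of that 2D sanity check (the random image tensor is a stand-in for a real test image):

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 64, 64)                 # replace with a real test image

# 90-degree rotation plus a 0.2 translation in normalized coordinates
theta = torch.tensor([[[0.0, -1.0, 0.2],
                       [1.0,  0.0, 0.0]]])
grid = F.affine_grid(theta, img.shape, align_corners=True)

# padding_mode='zeros' leaves the shifted-in border black, as described above
out = F.grid_sample(img, grid, mode='bilinear',
                    padding_mode='zeros', align_corners=True)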
nice.
https://github.com/OpenTalker/video-retalking/blob/d32e8e58248255e2d243eeaf3cba545dbe505ca8/utils/flow_util.py#L4
https://github.com/OpenTalker/video-retalking/blob/d32e8e58248255e2d243eeaf3cba545dbe505ca8/models/DNet.py#L88
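If I read those right, they implement the same idea discussed above: add a base grid to a pixel-space flow and normalize it into grid_sample's [-1, 1] coordinates. A rough 2D sketch of that idea (my own naming, not their exact API):

import torch
import torch.nn.functional as F

def flow_to_grid(flow: torch.Tensor) -> torch.Tensor:
    # flow: (N, 2, H, W) pixel offsets; channel 0 assumed to be the x/width offset
    n, _, h, w = flow.shape
    norm = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0]).view(1, 2, 1, 1)
    flow = flow / norm                                         # pixels -> [-1, 1] units
    theta = torch.eye(2, 3).unsqueeze(0).expand(n, -1, -1)     # identity base grid
    base = F.affine_grid(theta, (n, 1, h, w), align_corners=True)
    return base + flow.permute(0, 2, 3, 1)                     # (N, H, W, 2) for grid_sample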