NVlabs / eg3d

Training poses for AFHQ dataset #55

Open anuragranj opened 2 years ago

anuragranj commented 2 years ago

Very nice work. How did you generate the camera poses for the AFHQ dataset?

LDYang694 commented 1 year ago

Yes, I also want to know. Have you figured out how to generate the AFHQ camera poses, or is there a script for that? Thanks.

MontaEllis commented 1 year ago

Same question

RaymondJiangkw commented 1 year ago

Based on their paper, the authors refer to this repo for 2D landmark detection of cats, and to OpenCV's Perspective-n-Point (PnP) implementation for estimating the camera poses that map 3D points to their 2D correspondences.

So it seems that the ground-truth 3D landmarks needed for PnP are missing?
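For reference, the landmark-plus-PnP step can be sketched even without the missing data. This is a minimal sketch under stated assumptions: the 3D "template" landmarks are hypothetical random points (the real AFHQ 3D landmarks are exactly what is missing), the 512-pixel image size is assumed, and instead of OpenCV's `cv2.solvePnP` it swaps in a plain Direct Linear Transform (DLT) in NumPy so it runs without OpenCV. DLT needs at least six non-coplanar correspondences and ignores lens distortion; `pnp_dlt` and `to_cam2world` are hypothetical helper names.

```python
import numpy as np


def pnp_dlt(points_3d, points_2d, K):
    """Estimate world-to-camera (R, t) from 3D-2D correspondences with a linear DLT.

    Assumes a pinhole camera with known intrinsics K, no lens distortion, and at
    least six non-coplanar points. A sketch, not a replacement for cv2.solvePnP.
    """
    n = points_3d.shape[0]
    # Map pixel coordinates to normalized camera coordinates with K^-1.
    rays = np.linalg.inv(K) @ np.hstack([points_2d, np.ones((n, 1))]).T
    u = rays[0] / rays[2]
    v = rays[1] / rays[2]

    # Two linear equations per correspondence in the 12 entries of P = [R | t].
    X = np.hstack([points_3d, np.ones((n, 1))])
    zeros = np.zeros((n, 4))
    A = np.vstack([np.hstack([X, zeros, -u[:, None] * X]),
                   np.hstack([zeros, X, -v[:, None] * X])])
    # Least-squares solution: right singular vector of the smallest singular value.
    P = np.linalg.svd(A)[2][-1].reshape(3, 4)

    # P is known only up to scale and sign. Choose the sign with det > 0, snap the
    # 3x3 block to the nearest rotation, and divide t by the recovered scale.
    if np.linalg.det(P[:, :3]) < 0:
        P = -P
    U, S, Vt = np.linalg.svd(P[:, :3])
    return U @ Vt, P[:, 3] / S.mean()


def to_cam2world(R, t):
    """Invert world-to-camera (R, t) into a 4x4 camera-to-world matrix."""
    pose = np.eye(4)
    pose[:3, :3] = R.T
    pose[:3, 3] = -R.T @ t
    return pose


# Synthetic check: hypothetical template landmarks observed from a known pose.
rng = np.random.default_rng(0)
template = rng.uniform(-0.5, 0.5, size=(8, 3))
res = 512  # assumed image size; EG3D-style intrinsics are normalized by it
K = np.array([[4.2647 * res, 0.0, 0.5 * res],
              [0.0, 4.2647 * res, 0.5 * res],
              [0.0, 0.0, 1.0]])
a = np.deg2rad(20.0)
R_true = np.array([[np.cos(a), 0.0, np.sin(a)],
                   [0.0, 1.0, 0.0],
                   [-np.sin(a), 0.0, np.cos(a)]])  # 20 degrees about the y axis
t_true = np.array([0.05, -0.02, 2.7])
cam_pts = template @ R_true.T + t_true  # points in camera coordinates
pixels = (cam_pts @ K.T)[:, :2] / cam_pts[:, 2:3]

R_est, t_est = pnp_dlt(template, pixels, K)
print(np.allclose(R_est, R_true, atol=1e-6), np.allclose(t_est, t_true, atol=1e-6))
print(to_cam2world(R_est, t_est).round(3))
```

With noisy detected landmarks, a linear DLT is usually refined by a nonlinear reprojection-error minimization, which is what an iterative PnP solver provides.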

RaymondJiangkw commented 1 year ago

Okay, for those who might be interested, I found two possible solutions.

  1. Following pix2pix3D [Deng et al.], you could use unsup3d to predict the pitch, yaw, and roll of cat faces. These are sufficient because the translation vector used in EG3D is computed from the rotation matrix.
  2. I trained a 6DRepNet [Hempel et al.] model for cat faces (the checkpoint can be found here), which directly outputs the rotation matrix. The model was trained, without any further modification, on the AFHQ dataset provided by EG3D, which pairs each cropped cat face with its camera parameters. I tested a few examples myself and found that it basically works. However, 6DRepNet's camera convention differs somewhat from EG3D's. The code below shows how to convert the predicted rotation matrix into camera parameters that can be plugged directly into EG3D.
    
```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from model import SixDRepNet  # SixDRepNet class from the 6DRepNet code base (model.py)

device = 'cuda'

# Load the model.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# RandomResizedCrop is a training-time augmentation, so repeated predictions on the
# same image differ slightly; transforms.Resize((224, 224)) would be deterministic.
transformations = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1)),
    transforms.ToTensor(),
    normalize,
])

model = SixDRepNet(backbone_name='RepVGG-B1g2', backbone_file='', deploy=True, pretrained=False)
model.requires_grad_(False)
model.load_state_dict(torch.load('./checkpoint.pth', map_location=device))
model.to(device).eval()

# Define the camera conversion functions.
@torch.no_grad()
def convert_predicted_cam(cam: torch.Tensor):
    # Transpose each predicted rotation, then flip its y and z axes (6DRepNet -> EG3D).
    Rot = torch.diag(torch.tensor([1.0, -1.0, -1.0], device=cam.device))
    return cam.permute(0, 2, 1) @ Rot[None, ...]

def gen_pose(rot_mat):
    rot_mat = np.array(rot_mat).copy()
    forward = rot_mat[:, 2]
    translation = forward * -2.7  # camera 2.7 units from the origin, looking at it
    pose = np.eye(4)
    pose[:3, :3] = rot_mat
    pose[:3, 3] = translation
    return pose

def gen_label(out: torch.Tensor):
    pose = gen_pose(convert_predicted_cam(out).squeeze().cpu().numpy())
    # Normalized intrinsics: focal length 4.2647, principal point at the image center.
    intrinsics = np.array([[4.2647, 0.0, 0.5],
                           [0.0, 4.2647, 0.5],
                           [0.0, 0.0, 1.0]])
    # 16 values of the 4x4 pose followed by 9 values of the 3x3 intrinsics.
    return np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist()

# Height and width of the image should be the same.
image = Image.open('/path/to/cat_face.png').convert('RGB')
image = transformations(image).to(device)
with torch.no_grad():
    raw = model(image[None, ...])
cam = gen_label(raw)
print('Label:', cam)
```
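For option 1, only angles are available, and `gen_pose` shows why that is enough: the translation is fixed by the rotation's forward axis and the 2.7 radius. The sketch below builds an EG3D-style 25-value label directly from yaw and pitch, roughly in the spirit of the look-at construction in EG3D's `camera_utils.py` (`LookAtPoseSampler`), but with the look-at target fixed at the origin to match `gen_pose`. Roll is ignored because a look-at camera keeps a fixed world up direction (+y). How unsup3d's angles (signs, zero point, units) map onto these yaw/pitch values is an assumption you would need to check, and `look_at_pose` and `make_label` are hypothetical helper names.

```python
import numpy as np

RADIUS = 2.7  # camera distance from the origin, as in gen_pose
INTRINSICS = np.array([[4.2647, 0.0, 0.5],
                       [0.0, 4.2647, 0.5],
                       [0.0, 0.0, 1.0]])


def look_at_pose(yaw, pitch, radius=RADIUS):
    """Hypothetical helper: 4x4 cam2world for a camera orbiting and facing the origin.

    yaw turns the camera around the world y axis (up); positive pitch raises it.
    Angles are in radians; pitch must stay away from +-90 degrees, where the
    look-at frame is undefined. Camera axes: x right, y down, z forward.
    """
    origin = radius * np.array([np.sin(yaw) * np.cos(pitch),
                                np.sin(pitch),
                                np.cos(yaw) * np.cos(pitch)])
    z_axis = -origin / np.linalg.norm(origin)   # forward: towards the origin
    x_axis = np.cross(z_axis, [0.0, 1.0, 0.0])  # right
    x_axis /= np.linalg.norm(x_axis)
    y_axis = np.cross(z_axis, x_axis)           # down, completing a right-handed frame
    pose = np.eye(4)
    pose[:3, :3] = np.stack([x_axis, y_axis, z_axis], axis=1)
    pose[:3, 3] = origin
    return pose


def make_label(pose, intrinsics=INTRINSICS):
    """Flatten pose (16 values) and intrinsics (9 values) into one label, like gen_label."""
    return np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist()


print(look_at_pose(0.0, 0.0).round(4))
label = make_label(look_at_pose(np.deg2rad(30.0), np.deg2rad(10.0)))
print(len(label))
```

For a frontal view (yaw = pitch = 0) this produces the same matrix that `convert_predicted_cam` followed by `gen_pose` gives for an identity 6DRepNet prediction: rotation diag(1, -1, -1) with the camera at (0, 0, 2.7).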

luchaoqi commented 1 year ago

> Based on their paper, the authors refer to this repo for 2D landmark detection of cats, and then refer to OpenCV's Perspective-n-Point algorithm implementation for estimating camera poses to transform 3D points to their 2D correspondence.
>
> Thus, it seems that the ground-truth 3D landmarks are missing?

It seems the paper made a mistake in the citation:

> Camera poses were extracted via landmark detection [20] and an open-source Perspective-n-Point algorithm [3]. We augment the dataset with horizontal flips.

It should read "Camera poses were extracted via landmark detection [32]".
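The quoted sentence also mentions horizontal-flip augmentation. When an image is mirrored left-right, its camera label has to be mirrored as well. A minimal sketch, assuming the world x axis is the face's left-right axis (world y up, as in the look-at setup EG3D uses) and that the pixels are flipped separately: conjugating the 4x4 pose with M = diag(-1, 1, 1, 1) reflects the world and the image x axis together, so the camera's x position flips, the rotation stays proper, and the intrinsics are unchanged as long as the principal point is at the center. `flip_pose`, `flip_label`, and `yaw_pose` are hypothetical helper names.

```python
import numpy as np

# Homogeneous reflection of the world x axis (x -> -x).
MIRROR_X = np.diag([-1.0, 1.0, 1.0, 1.0])


def flip_pose(cam2world):
    """Hypothetical helper: camera pose for the horizontally mirrored image.

    Reflects the world (x -> -x) and the camera's image x axis together, so the
    3x3 block stays a proper rotation (det +1).
    """
    return MIRROR_X @ np.asarray(cam2world, dtype=float) @ MIRROR_X


def flip_label(label):
    """Mirror a 25-value label (16 pose values + 9 intrinsics). The intrinsics are
    kept as-is, which assumes the principal point is at the image center."""
    label = np.asarray(label, dtype=float)
    pose = flip_pose(label[:16].reshape(4, 4))
    return np.concatenate([pose.reshape(-1), label[16:]]).tolist()


def yaw_pose(yaw, radius=2.7):
    """Hypothetical example pose: camera yawed around the y axis, looking at the
    origin, with camera axes x right, y down, z forward."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, 0.0, -s, radius * s],
                     [0.0, -1.0, 0.0, 0.0],
                     [-s, 0.0, -c, radius * c],
                     [0.0, 0.0, 0.0, 1.0]])


flipped = flip_pose(yaw_pose(np.deg2rad(25.0)))
print(np.allclose(flipped, yaw_pose(np.deg2rad(-25.0))))
```

For a pure yaw, mirroring simply negates the yaw angle, which is a quick sanity check for any flip implementation.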

luchaoqi commented 10 months ago

@RaymondJiangkw Thanks for the 6DRepNet checkpoint. Do you happen to know how to crop in-the-wild cat images the way EG3D does for in-the-wild celebrity images?

RaymondJiangkw commented 10 months ago

> @RaymondJiangkw thanks for the 6DRepNet checkpoint, do you happen to know how to crop in-the-wild cat images like eg3d does for in-the-wild celeb images?

I have tried a few approaches, but none of them gives reasonable and consistent results on truly diverse cat images. The strategy described in the paper, i.e., detecting landmarks and then solving for the camera parameters with PnP, may be an option for you.