Open crj1998 opened 2 months ago
def get_similarity_transform_matrix(from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """Fit a least-squares similarity transform from ``from_pts`` to ``to_pts``.

    For each batch item this solves for ``a``, ``b``, ``dx``, ``dy`` minimising the
    squared distance between ``M @ [x, y, 1]`` and the matching destination point,
    where::

        M = [[ a, b, dx],
             [-b, a, dy],
             [ 0, 0,  1]]

    i.e. a rotation combined with a uniform scale and a translation.

    Args:
        from_pts: b x n x 2 source points, (x, y) per point.
        to_pts: b x n x 2 destination points, (x, y) per point.

    Returns:
        torch.Tensor: b x 3 x 3 matrix mapping ``from_pts`` onto ``to_pts``.
        If all source points of an item coincide, the normaliser ``a1`` is zero
        and that item's matrix contains NaN/inf.
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    from_delta = from_pts - mfrom  # centred source points
    to_delta = to_pts - mto  # centred destination points

    a1 = from_delta.square().sum(dim=[1, 2])  # b: sum(x^2 + y^2)
    c1 = (to_delta * from_delta).sum(dim=[1, 2])  # b: sum(x * x' + y * y')
    c2 = (
        to_delta[:, :, 0] * from_delta[:, :, 1] - to_delta[:, :, 1] * from_delta[:, :, 0]
    ).sum(dim=1)  # b: sum(x' * y - y' * x)

    a = c1 / a1
    b = c2 / a1
    # Translation chosen so the source centroid maps exactly onto the destination centroid.
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)
    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)


def get_face_align_matrix(face_pts: torch.Tensor, target_pts: torch.Tensor) -> torch.Tensor:
    """Similarity transform taking detected face landmarks onto a landmark template.

    Args:
        face_pts: b x n x 2 landmark coordinates, (x, y) per point.
        target_pts: n x 2, 1 x n x 2 or b x n x 2 template coordinates. A single
            template is broadcast across the batch. It is cast to the dtype/device
            of ``face_pts``.

    Returns:
        torch.Tensor: b x 3 x 3 matrix, see ``get_similarity_transform_matrix``.
    """
    target_pts = target_pts.to(face_pts)
    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)
    assert target_pts.shape == face_pts.shape
    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return cached (yy, xx) float index grids, each of shape h x w.

    The tensors are shared between calls through ``lru_cache``; callers must not
    modify them in place.
    """
    yy, xx = torch.meshgrid(
        torch.arange(h).float(), torch.arange(w).float(), indexing='ij'
    )
    return yy, xx


def inverted_warp_transform(coords: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    """Map coordinates through the inverse of a 3 x 3 homogeneous transform.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). Coordinates in the transformed
            (aligned) image.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned image.

    Returns:
        torch.Tensor: b x n x 2 (x, y). The corresponding original-image coordinates.
    """
    coords_homo = torch.cat([coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3
    inv_matrix = torch.linalg.inv(matrix)  # b x 3 x 3
    # Row-vector convention: p' = p @ M^T.
    coords_homo = torch.bmm(coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def _forge_grid(
    matrix: torch.Tensor,
    output_shape: Tuple[int, int],
) -> torch.Tensor:
    """Find, for every output pixel, the original-image coordinate it samples from.

    Args:
        matrix: b x 3 x 3 matrix taking original-image pixel coordinates to
            output-image pixel coordinates; it is inverted here.
        output_shape (tuple): (h, w, ...) of the output image; extra entries are ignored.

    Returns:
        torch.Tensor: b x h x w x 2, where ``out[i, y, x]`` is the (x, y)
        original-image pixel coordinate for output pixel (x, y) of batch item i.
    """
    batch_size = matrix.size(0)
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w each
    in_xy = torch.stack([xx, yy], dim=-1).to(matrix)  # h x w x 2, matrix dtype/device
    in_xy = in_xy.reshape(1, h * w, 2).expand(batch_size, h * w, 2)  # b x (h*w) x 2
    out_xy = inverted_warp_transform(in_xy, matrix)  # b x (h*w) x 2
    return out_xy.reshape(batch_size, h, w, 2)


def make_warp_grid(
    matrix: torch.Tensor,
    warped_shape: Tuple[int, int],
    orig_shape: Tuple[int, int],
) -> torch.Tensor:
    """Build an ``F.grid_sample`` grid that warps the original image by ``matrix``.

    Args:
        matrix: b x 3 x 3 matrix from original-image to warped-image pixel coordinates.
        warped_shape: (h, w, ...) of the warped (output) image.
        orig_shape: (h, w, ...) of the original (input) image.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y) grid, normalised by the original image
        size so that [0, size] maps to [-1, 1].
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    grid = _forge_grid(matrix, warped_shape)
    # NOTE(review): normalises with x / W * 2 - 1 (no half-pixel offset). With
    # grid_sample(align_corners=False), this samples pixel position x - 0.5.
    # Confirm that matches the keypoint convention.
    return grid / w_h * 2 - 1


from torchvision.utils import make_grid, save_image


class IDLoss(nn.Module):
    """Identity loss between face crops of ``images`` and reference ``targets``.

    Faces are aligned to a five-point template with a similarity transform and
    embedded with ``iresnet50``. The loss is ``1 - cosine_similarity`` of the
    embeddings. Targets are normalised with ``(t - 0.5) / 0.5``, while aligned
    faces are fed as-is. Presumably, targets are in [0, 1] and images already in
    [-1, 1]; confirm with callers.
    """

    def __init__(self, resnet_path="/mnt/afs/chenrenjie/workspace/photomaker/faceid/w600k_r50.pth", out_size=112):
        super().__init__()
        # Five-point template, (x, y) per point, defined for a 112 x 112 crop.
        target_pts = np.array(
            [
                [38.2946, 51.6963],  # left eye
                [73.5318, 51.5014],  # right eye
                [56.0252, 71.7366],  # nose tip
                [41.5493, 92.3655],  # left mouth corner
                [70.7299, 92.2041],  # right mouth corner
            ],
        )
        old_size = 112
        target_pts = target_pts / old_size * out_size
        # The crop size must match the template scale, so remember it.
        self.out_size = out_size
        self.register_buffer("target_pts", torch.from_numpy(target_pts).float())
        self.iresnet = iresnet50(pretrained=resnet_path)

    def _align_faces(self, images: torch.Tensor, kps: torch.Tensor) -> torch.Tensor:
        """Crop ``out_size`` x ``out_size`` aligned faces out of ``images``.

        Args:
            images: b x c x h x w float image batch.
            kps: b x n x 2 keypoints, (x, y) per point, in the same order as
                ``target_pts``. Presumably normalised to [0, 1]; they are scaled
                by the image width/height here.

        Returns:
            torch.Tensor: b x c x out_size x out_size crops. Samples whose scaled
            keypoints sum to <= 0.01 (effectively missing) are plainly resized
            instead. An all-zero keypoint set would make the similarity fit
            divide by zero.
        """
        _, _, h, w = images.shape
        size = (self.out_size, self.out_size)
        # Column 0 is x and column 1 is y (same layout as target_pts and
        # make_warp_grid), so x scales with the width and y with the height.
        kps = kps * torch.tensor([w, h], dtype=images.dtype, device=images.device)
        valid = kps.flatten(1).sum(dim=1) > 0.01  # b, per-sample "has keypoints"
        if not valid.any():
            return F.interpolate(images, size=size, mode="bilinear")

        matrix = get_face_align_matrix(kps[valid], self.target_pts)
        grid = make_warp_grid(matrix, orig_shape=(h, w), warped_shape=size)
        aligned = F.grid_sample(images[valid], grid, mode="bilinear", align_corners=False)
        if valid.all():
            return aligned

        # Mixed batch: resize everything, then overwrite the samples we could align.
        faces = F.interpolate(images, size=size, mode="bilinear")
        faces[valid] = aligned
        return faces

    @torch.no_grad()
    def similarity(self, images: torch.Tensor, kps: torch.Tensor, targets: torch.Tensor) -> float:
        """Mean cosine similarity between face embeddings of ``images`` and ``targets``.

        Args:
            images: b x c x h x w generated images.
            kps: b x n x 2 face keypoints for ``images`` (see ``_align_faces``).
            targets: b x c x H x W reference images.

        Returns:
            float: batch-mean cosine similarity (no gradient).
        """
        images = images.float()
        faces = self._align_faces(images, kps)
        targets = F.interpolate(targets, size=(self.out_size, self.out_size), mode="bilinear")
        # save_image(make_grid(torch.cat([targets, (faces+1.0)/2.0]), nrow=4, padding = 4, normalize=False), f"sample_id.jpg")
        target_emb = self.iresnet((targets - 0.5) / 0.5)
        face_emb = self.iresnet(faces)
        cosim = F.cosine_similarity(face_emb, target_emb, dim=-1)
        return cosim.mean().item()

    def forward(self, images: torch.Tensor, kps: torch.Tensor, targets: torch.Tensor, step) -> torch.Tensor:
        """Identity loss ``mean(1 - cos_sim)``; gradients flow only through ``images``.

        Args:
            images: b x c x h x w generated images.
            kps: b x n x 2 face keypoints for ``images`` (see ``_align_faces``).
            targets: b x c x H x W reference images (no gradient).
            step: only used by the commented-out debug dump below.

        Returns:
            torch.Tensor: scalar loss.
        """
        images = images.float()
        faces = self._align_faces(images, kps)
        with torch.no_grad():
            targets = F.interpolate(targets, size=(self.out_size, self.out_size), mode="bilinear")
            # save_image(make_grid(torch.cat([targets, (faces+1.0)/2.0]), nrow=4, padding = 4, normalize=False), f"sample_{step}_id.jpg")
            target_emb = self.iresnet((targets - 0.5) / 0.5)
        face_emb = self.iresnet(faces)
        cosim = F.cosine_similarity(face_emb, target_emb, dim=-1)
        return (1.0 - cosim).mean()


if __name__ == "__main__":
    from typing import List