crj1998 / simplified_dfl

0 stars 0 forks source link

torch warpaffine #1

Open crj1998 opened 2 months ago

crj1998 commented 2 months ago

def get_similarity_transform_matrix(from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """Fit a least-squares similarity transform mapping ``from_pts`` onto ``to_pts``.

    The fitted transform has the form::

        x' =  a * x + b * y + dx
        y' = -b * x + a * y + dy

    Args:
        from_pts, to_pts: b x n x 2 point sets, stored as (x, y).

    Returns:
        torch.Tensor: b x 3 x 3 homogeneous matrices
            ``[[a, b, dx], [-b, a, dy], [0, 0, 1]]``.
    """
    # Centre both point sets so scale/rotation can be solved independently
    # of the translation.
    src_center = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    dst_center = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    src = from_pts - src_center
    dst = to_pts - dst_center

    # Squared norm of the source cloud, plus the "dot" and "cross" correlations
    # between the two centred clouds.
    src_norm = src.square().sum([1, 2], keepdim=False)  # b
    dot = (dst * src).sum([1, 2], keepdim=False)  # b
    cross = (dst[:, :, 0] * src[:, :, 1] - dst[:, :, 1] * src[:, :, 0]).sum([1], keepdim=False)  # b

    a = dot / src_norm
    b = cross / src_norm

    # Translation: whatever is left after applying the linear part to the
    # source centroid.
    src_cx, src_cy = src_center[:, 0, 0], src_center[:, 0, 1]
    dst_cx, dst_cy = dst_center[:, 0, 0], dst_center[:, 0, 1]
    dx = dst_cx - a * src_cx - b * src_cy  # b
    dy = dst_cy + b * src_cx - a * src_cy  # b

    one = torch.ones_like(src_norm)
    zero = torch.zeros_like(src_norm)

    # Row-major list of the nine matrix entries.
    entries = [
        a, b, dx,
        -b, a, dy,
        zero, zero, one,
    ]
    return torch.stack(entries, dim=-1).reshape(-1, 3, 3)

def get_face_align_matrix(face_pts: torch.Tensor, target_pts: torch.Tensor):
    """Compute the similarity transform that aligns ``face_pts`` to ``target_pts``.

    Args:
        face_pts: Landmarks to align; passed unchanged as the source points.
        target_pts: Template landmarks. They are converted to the dtype and
            device of ``face_pts``. A 2-D template gets a leading batch
            dimension, and a batch of size one is broadcast to
            ``face_pts.shape``.

    Returns:
        The result of ``get_similarity_transform_matrix(face_pts, target_pts)``.

    Raises:
        AssertionError: If the shapes still differ after broadcasting.
    """
    template = target_pts.to(face_pts)
    if template.dim() == 2:
        # A single unbatched template: add the batch axis.
        template = template[None]
    if template.size(0) == 1:
        # Share the one template across every item in the batch.
        template = torch.broadcast_to(template, face_pts.shape)
    assert template.shape == face_pts.shape
    return get_similarity_transform_matrix(face_pts, template)

@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return cached float row/column index grids for an ``h`` x ``w`` image.

    Results are memoised per ``(h, w)``, so the returned tensors are shared
    between calls and should not be modified in place.

    Returns:
        Tuple ``(yy, xx)`` of h x w tensors, where ``yy[i, j] == i`` and
        ``xx[i, j] == j``.
    """
    row_idx = torch.arange(h).float()
    col_idx = torch.arange(w).float()
    # 'ij' indexing: first output varies along rows, second along columns.
    grid_y, grid_x = torch.meshgrid(row_idx, col_idx, indexing='ij')
    return grid_y, grid_x

def inverted_warp_transform(coords: torch.Tensor, matrix: torch.Tensor):
    """Map coordinates back through the inverse of ``matrix``.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. Homogeneous matrices whose inverse is applied to
            ``coords``.

    Returns:
        torch.Tensor: b x n x 2 (x, y). The coordinates before ``matrix`` was
            applied.
    """
    # Lift (x, y) to homogeneous (x, y, 1).
    ones = torch.ones_like(coords[:, :, [0]])
    homogeneous = torch.cat([coords, ones], dim=-1)  # b x n x 3

    inverse = torch.linalg.inv(matrix)  # b x 3 x 3
    # Row-vector convention: p @ M^-T is the same as (M^-1 @ p^T)^T.
    projected = torch.bmm(homogeneous, inverse.transpose(1, 2))  # b x n x 3

    # Divide by the homogeneous coordinate to return to 2-D.
    xy = projected[:, :, :2]
    scale = projected[:, :, 2:3]
    return xy / scale

def _forge_grid(
    matrix: torch.Tensor,
    output_shape: Tuple[int, int],
) -> torch.Tensor:
    """Build, for every output pixel, the source coordinate it samples from.

    Each integer pixel position (x, y) of the output image is pushed through
    ``inverted_warp_transform`` with ``matrix``.

    Args:
        matrix: b x 3 x 3 matrices handed to ``inverted_warp_transform``.
        output_shape (tuple): (h, w, ...). Only the first two entries are read.

    Returns:
        torch.Tensor: b x h x w x 2 map ``G`` where ``G[i, y, x]`` holds the
            (x, y) coordinate obtained by applying the inverse of
            ``matrix[i]`` to output pixel (x, y).
    """
    batch_size = matrix.size(0)
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w each, cached on CPU

    # Stack as (x, y) and match the matrix's dtype/device *before* expanding,
    # so only one h x w grid is converted and `bmm` sees matching dtypes.
    pixels = torch.stack([xx, yy], dim=-1).to(matrix)  # h x w x 2
    pixels = pixels.reshape(1, h * w, 2).expand(batch_size, -1, -1)  # b x (h*w) x 2

    sources = inverted_warp_transform(pixels, matrix)  # b x (h*w) x 2
    return sources.reshape(batch_size, h, w, 2)

def make_warp_grid(
    matrix: torch.Tensor,
    warped_shape: Tuple[int, int],
    orig_shape: Tuple[int, int]
):
    """Build a normalised sampling grid for warping an image with ``matrix``.

    Args:
        matrix: bx3x3 matrix.

        warped_shape: The target image shape to transform to, (height, width, ...).

        orig_shape: Shape of the source image, (height, width, ...). Used to
            normalise the grid coordinates.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), where source coordinate 0 maps to
            -1 and a coordinate equal to the source width/height maps to +1.
    """
    # NOTE(review): this normalisation has no half-pixel offset; confirm it
    # matches the `align_corners` convention of whatever consumes the grid.
    src_h, src_w, *_ = orig_shape
    pixel_grid = _forge_grid(matrix, warped_shape)
    # (width, height) order so it lines up with the (x, y) last axis.
    extent_xy = torch.tensor([src_w, src_h]).to(matrix).reshape(1, 1, 1, 2)
    return pixel_grid / extent_xy * 2 - 1

from torchvision.utils import make_grid, save_image

class IDLoss(nn.Module):
    """Identity loss based on cosine similarity of ``iresnet50`` embeddings.

    Faces are cropped from ``images`` by fitting a similarity transform from
    the supplied keypoints to a fixed five-point template, then compared with
    the embeddings of ``targets``.
    """

    # Side length of the crops handed to the embedding network.
    _FACE_SIZE = 112

    def __init__(self, resnet_path="/mnt/afs/chenrenjie/workspace/photomaker/faceid/w600k_r50.pth", out_size=112):
        """
        Args:
            resnet_path: Forwarded to ``iresnet50(pretrained=...)``.
            out_size: Side length the 112-pixel landmark template is rescaled
                to. NOTE(review): crops are always 112 x 112, so any other
                value moves the template relative to the crop - confirm intent.
        """
        super().__init__()
        # Five-point template as (x, y), defined on a 112 x 112 crop.
        target_pts = np.array(
            [
                [38.2946, 51.6963],  # left eye
                [73.5318, 51.5014],  # right eye
                [56.0252, 71.7366],  # nose tip
                [41.5493, 92.3655],  # left mouth corner
                [70.7299, 92.2041],  # right mouth corner
            ],
        )
        old_size = 112
        target_pts = target_pts / old_size * out_size
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("target_pts", torch.from_numpy(target_pts).float())
        self.iresnet = iresnet50(pretrained=resnet_path)

    def _crop_faces(self, images: torch.Tensor, kps: torch.Tensor) -> torch.Tensor:
        """Return 112 x 112 face crops of ``images`` aligned by ``kps``."""
        _, _, h, w = images.shape
        images = images.float()
        # Keypoints are (x, y), so x scales with the width and y with the
        # height. Assumes `kps` is normalised as (x / width, y / height) -
        # TODO confirm against the data pipeline. Identical to the previous
        # [h, w] scaling for square images.
        size_xy = torch.tensor([w, h], dtype=images.dtype, device=images.device)
        kps = kps.to(images) * size_xy
        if kps.sum() <= 0.01:
            # No usable keypoints in the batch: a transform fitted to
            # coincident points is degenerate, so resize the whole image.
            return F.interpolate(images, size=(self._FACE_SIZE, self._FACE_SIZE), mode="bilinear")
        matrix = get_face_align_matrix(kps, self.target_pts)
        grid = make_warp_grid(
            matrix,
            orig_shape=(h, w),
            warped_shape=(self._FACE_SIZE, self._FACE_SIZE),
        )
        return F.grid_sample(images, grid, mode="bilinear", align_corners=False)

    def _embed_targets(self, targets: torch.Tensor) -> torch.Tensor:
        """Resize ``targets`` to 112 x 112, remap via ``(t - 0.5) / 0.5`` and embed."""
        targets = F.interpolate(targets, size=(self._FACE_SIZE, self._FACE_SIZE), mode="bilinear")
        return self.iresnet((targets - 0.5) / 0.5)

    @torch.no_grad()
    def similarity(self, images: torch.Tensor, kps: torch.Tensor, targets: torch.Tensor) -> float:
        """Return the batch-mean cosine similarity between faces and targets.

        Args:
            images: b x c x h x w images containing the faces.
            kps: b x n x 2 keypoints as (x, y), matching the five-point template.
            targets: Reference images; resized and remapped before embedding.

        Returns:
            float: Mean cosine similarity over the batch.
        """
        faces = self._crop_faces(images, kps)
        target_emb = self._embed_targets(targets)
        face_emb = self.iresnet(faces)
        cosim = F.cosine_similarity(face_emb, target_emb, dim=-1)
        return cosim.mean().item()

    def forward(self, images: torch.Tensor, kps: torch.Tensor, targets: torch.Tensor, step) -> torch.Tensor:
        """Return the batch-mean ``1 - cosine_similarity`` identity loss.

        Args:
            images: b x c x h x w images containing the faces.
            kps: b x n x 2 keypoints as (x, y), matching the five-point template.
            targets: Reference images; embedded without gradient tracking.
            step: Unused; kept so existing callers keep working.

        Returns:
            torch.Tensor: Scalar loss.
        """
        faces = self._crop_faces(images, kps)

        # The reference embedding is a fixed target; no gradients needed.
        with torch.no_grad():
            target_emb = self._embed_targets(targets)
        face_emb = self.iresnet(faces)

        cosim = F.cosine_similarity(face_emb, target_emb, dim=-1)
        return (1.0 - cosim).mean()

if __name__ == "__main__":

    from typing import List