johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

Implement pre-trained gaze estimation model #25

Open johndpope opened 2 weeks ago

johndpope commented 2 weeks ago

Plus blink loss

Note that you'll need to preprocess your data to provide the left eye image, right eye image, 2D keypoints, target gaze, and target blink for each sample during training.
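
The per-sample layout this implies can be sanity-checked before training. Below is a minimal sketch, not part of the original code: `EXPECTED_SHAPES` and `check_sample` are hypothetical names, the 224x224 eye-crop size is an assumption (the usual VGG-16 input size), and the other shapes are read off the layer sizes in `GazeBlinkLoss` (136 keypoint values, 2 gaze outputs, 1 blink output). It works with anything exposing a `.shape`, such as torch tensors or NumPy arrays.

```python
from types import SimpleNamespace

# Per-sample shapes implied by the layer sizes in GazeBlinkLoss.
# ASSUMPTION: eye crops are 3x224x224 (the usual VGG-16 input size).
EXPECTED_SHAPES = {
    "left_eye": (3, 224, 224),
    "right_eye": (3, 224, 224),
    "keypoints": (136,),    # matches nn.Linear(136, 64); presumably 68 landmarks x (x, y)
    "target_gaze": (2,),    # matches the 2-unit gaze head
    "target_blink": (1,),   # matches the 1-unit blink head
}


def check_sample(sample):
    """Return a list of problems; an empty list means the sample looks consistent."""
    problems = []
    for name, expected in EXPECTED_SHAPES.items():
        if name not in sample:
            problems.append(f"missing field: {name}")
            continue
        actual = tuple(sample[name].shape)
        if actual != expected:
            problems.append(f"{name}: expected shape {expected}, got {actual}")
    return problems


# Stand-in objects with a .shape attribute; real code would pass tensors.
sample = {name: SimpleNamespace(shape=shape) for name, shape in EXPECTED_SHAPES.items()}
print(check_sample(sample))  # []
```

If the keypoints come straight from MediaPipe FaceMesh, which returns several hundred landmarks per face, they would need to be reduced to whatever 136-value layout the keypoint network is meant to receive.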


'''
In this updated code:

We define a GazeBlinkLoss module that combines the gaze and blink prediction tasks.
The module consists of a backbone network (VGG-16), a keypoint network, a gaze prediction head, and a blink prediction head.
The backbone network is used to extract features from the left and right eye images separately. The features are then summed to obtain a combined eye representation.
The keypoint network takes the 2D keypoints as input and produces a latent vector of size 64.
The gaze prediction head takes the concatenated eye features and keypoint features as input and predicts the gaze direction.
The blink prediction head takes only the eye features as input and predicts the blink probability.
The gaze loss is computed using both MAE and MSE losses, weighted by w_mae and w_mse, respectively.
The blink loss is computed using binary cross-entropy loss.
The total loss is the sum of the gaze loss and blink loss.

To train this model, you can follow the training procedure you described:

Use Adam optimizer with the specified hyperparameters.
Train for 60 epochs with a batch size of 64.
Use the one-cycle learning rate schedule.
Treat the predictions from RT-GENE and RT-BENE as ground truth.

This code provides a starting point for aligning the MediaPipe-based gaze and blink loss with the approach you described. You may need to make further adjustments based on your specific dataset and requirements.

'''
import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class GazeBlinkLoss(nn.Module):
    def __init__(self, device, w_mae=15, w_mse=10):
        super(GazeBlinkLoss, self).__init__()
        self.device = device
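        # NOTE: face_mesh is created here but never used in forward(); eye crops and
        # keypoints must be extracted before calling this module.
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5)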
        self.w_mae = w_mae
        self.w_mse = w_mse

        self.backbone = self._create_backbone()
        self.keypoint_net = self._create_keypoint_net()
        self.gaze_head = self._create_gaze_head()
        self.blink_head = self._create_blink_head()

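    def _create_backbone(self):
        # `pretrained=True` is deprecated in torchvision >= 0.13; use the weights enum.
        model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
        # Keeping only VGG's first FC layer yields 4096-d features, which do not fit the
        # heads (gaze: 256 + 64 = 320 inputs, blink: 256). The 256-d projection is an assumption.
        first_fc = list(model.classifier.children())[0]
        model.classifier = nn.Sequential(first_fc, nn.ReLU(), nn.Linear(4096, 256))
        return model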

    def _create_keypoint_net(self):
        return nn.Sequential(
            nn.Linear(136, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )

    def _create_gaze_head(self):
        return nn.Sequential(
            nn.Linear(320, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def _create_blink_head(self):
        return nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, left_eye, right_eye, keypoints, target_gaze, target_blink):
        # Extract eye features using the backbone
        left_features = self.backbone(left_eye)
        right_features = self.backbone(right_eye)
        eye_features = left_features + right_features

        # Extract keypoint features
        keypoint_features = self.keypoint_net(keypoints)

        # Predict gaze
        gaze_input = torch.cat((eye_features, keypoint_features), dim=1)
        predicted_gaze = self.gaze_head(gaze_input)

        # Predict blink
        predicted_blink = self.blink_head(eye_features)

        # Compute gaze loss
        gaze_mae_loss = nn.L1Loss()(predicted_gaze, target_gaze)
        gaze_mse_loss = nn.MSELoss()(predicted_gaze, target_gaze)
        gaze_loss = self.w_mae * gaze_mae_loss + self.w_mse * gaze_mse_loss

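        # Compute blink loss (target_blink must be a float tensor of shape (batch, 1),
        # the same shape as the predicted logits)
        blink_loss = nn.BCEWithLogitsLoss()(predicted_blink, target_blink)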

        # Total loss
        total_loss = gaze_loss + blink_loss

        return total_loss, predicted_gaze, predicted_blink
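
The loss arithmetic in `forward` can be checked by hand. The following is a minimal plain-Python sketch for a single sample, mirroring what `nn.L1Loss`, `nn.MSELoss` (both with mean reduction) and `nn.BCEWithLogitsLoss` compute. The function name `gaze_blink_loss` and the input values are made up for illustration; only the weights 15 and 10 come from the code above.

```python
import math


def gaze_blink_loss(pred_gaze, target_gaze, blink_logit, target_blink,
                    w_mae=15.0, w_mse=10.0):
    """Weighted MAE + MSE on the gaze vector, plus BCE-with-logits on the blink logit."""
    n = len(pred_gaze)
    mae = sum(abs(p - t) for p, t in zip(pred_gaze, target_gaze)) / n
    mse = sum((p - t) ** 2 for p, t in zip(pred_gaze, target_gaze)) / n
    # Numerically stable binary cross-entropy on a raw logit x with label y:
    # max(x, 0) - x * y + log(1 + exp(-|x|))
    bce = (max(blink_logit, 0.0) - blink_logit * target_blink
           + math.log1p(math.exp(-abs(blink_logit))))
    total = w_mae * mae + w_mse * mse + bce
    return total, mae, mse, bce


# Illustrative values: gaze off by (0.1, -0.2), blink logit 0 (p = 0.5), true blink = 1.
total, mae, mse, bce = gaze_blink_loss([0.1, -0.2], [0.0, 0.0], 0.0, 1.0)
print(round(total, 4))  # 3.1931 = 15 * 0.15 + 10 * 0.025 + ln(2)
```

With these weights the gaze terms dominate the blink term for even small gaze errors, so the unweighted sum `gaze_loss + blink_loss` may need a separate blink weight in practice.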

JaLnYn commented 5 days ago

Is there a specific library you had in mind for gaze tracking? I found a couple, but I don't think they use deep learning techniques. I'm not sure they use the GPU either, but I've tested a few and they work pretty well (on an HD webcam; I don't know about 224x224).

JaLnYn commented 5 days ago

Here's the gaze loss I've implemented for my attempt at a MegaPortraits reproduction: https://github.com/JaLnYn/talkinghead/blob/main/src/train/loss.py (I'm not sure it's 100% correct, though). It's based on https://github.com/glefundes/mobile-face-gaze

johndpope commented 1 day ago

Cool. I've drafted the multiscale patch discriminator here: https://github.com/johndpope/MegaPortrait-hack/pull/47

I'm running training on many heads, but it's a fight I can't win without GPU compute. I want to focus instead on disentangling a Selena Gomez video.

https://github.com/johndpope/VASA-1-hack/issues/16

I want to be able to do this: adjust face pitch, yaw, etc. https://github.com/RameenAbdal/StyleFlow/tree/webui