SamsungLabs / rome

Realistic mesh-based avatars. ECCV 2022

ImportError: cannot import name 'RtGeneEstimator' from 'rt_gene' #20

Open · ghost opened this issue 1 year ago

ghost commented 1 year ago

Hello!

Thank you for publishing your wonderful code! I am a student in Japan and would like to use your research for my own study!

However, I am having trouble finding a solution to the error I get when I run train.py (ImportError: cannot import name 'RtGeneEstimator' from 'rt_gene')

Could you please tell me how to resolve the error?

ghost commented 1 year ago

My environment is Windows, and in the terminal I ran:

python train.py --dataset_name voxceleb2hq_pairs --rome_data_dir data

GanZhengha commented 1 year ago

Have you resolved this? I have the same question.

ghost commented 1 year ago

I still haven't resolved it.

HowieMa commented 1 year ago

Hi, may I ask how to obtain the "train_keys.pkl" and "test_keys.pkl" files used by this code? Thanks!

johndpope commented 3 months ago

I think this is GazeEstimator https://github.com/Tobias-Fischer/rt_gene/blob/8fb461553e34ac3cc0b5dd3840ec7b3f541a7ffb/rt_gene/src/rt_gene/estimate_gaze_pytorch.py#L15
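
If that guess is right, one possible workaround (untested; a minimal sketch that assumes a local checkout of the Tobias-Fischer/rt_gene repo with its src/ directory on PYTHONPATH, and that ROME only needs the class under the old name) is to alias GazeEstimator to the name the import expects:

import sys
sys.path.append('/path/to/rt_gene/src')  # hypothetical path to a local rt_gene checkout

# Alias rt_gene's GazeEstimator under the name the ROME code tries to import
from rt_gene.estimate_gaze_pytorch import GazeEstimator as RtGeneEstimator

Whatever constructor arguments GazeEstimator actually takes would still have to match what ROME's gaze loss passes in, so this is only a starting point.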

UPDATE: Claude Opus rebuilds the code. Promising, but wrong. (Screenshot from 2024-05-13 12-55-55)

I gave the entire rt_gene code base to Claude and asked it to either implement rt_gene or just use MediaPipe. (Screenshots from 2024-05-13 13-10-21 and 2024-05-13 13-10-11)


import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class GazeLoss(object):
    def __init__(self, device):
        self.device = device
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5)

    @staticmethod
    def _connection_indices(connections):
        # FACEMESH_LEFT_EYE / FACEMESH_RIGHT_EYE are sets of (start, end) landmark-index
        # pairs, not plain indices, so flatten them into a unique index list first
        return sorted({idx for connection in connections for idx in connection})

    def forward(self, predicted_gaze, target_gaze, face_image):
        # Convert the face image from a CHW float tensor (assumed RGB in [0, 1]) to an HWC uint8 array
        face_image = face_image.detach().cpu().numpy().transpose(1, 2, 0)
        face_image = (face_image * 255).astype(np.uint8)

        # Extract eye landmarks using MediaPipe (FaceMesh expects an RGB image)
        results = self.face_mesh.process(face_image)
        if not results.multi_face_landmarks:
            return torch.tensor(0.0, device=self.device)

        left_idx = self._connection_indices(mp.solutions.face_mesh.FACEMESH_LEFT_EYE)
        right_idx = self._connection_indices(mp.solutions.face_mesh.FACEMESH_RIGHT_EYE)

        eye_landmarks = []
        for face_landmarks in results.multi_face_landmarks:
            left_eye_landmarks = [face_landmarks.landmark[idx] for idx in left_idx]
            right_eye_landmarks = [face_landmarks.landmark[idx] for idx in right_idx]
            eye_landmarks.append((left_eye_landmarks, right_eye_landmarks))

        # Compute loss for each eye
        loss = 0.0
        h, w = face_image.shape[:2]
        for left_eye, right_eye in eye_landmarks:
            # Convert normalized landmarks to pixel coordinates
            left_eye_pixels = np.array([(int(lm.x * w), int(lm.y * h)) for lm in left_eye], dtype=np.int32)
            right_eye_pixels = np.array([(int(lm.x * w), int(lm.y * h)) for lm in right_eye], dtype=np.int32)

            # Rasterize eye masks with OpenCV on numpy arrays (cv2.fillPoly cannot write into
            # torch tensors), using the convex hull so the polygon is well formed, then move to torch
            left_mask_np = np.zeros((h, w), dtype=np.float32)
            right_mask_np = np.zeros((h, w), dtype=np.float32)
            cv2.fillPoly(left_mask_np, [cv2.convexHull(left_eye_pixels)], 1.0)
            cv2.fillPoly(right_mask_np, [cv2.convexHull(right_eye_pixels)], 1.0)
            left_mask = torch.from_numpy(left_mask_np).unsqueeze(0).to(self.device)
            right_mask = torch.from_numpy(right_mask_np).unsqueeze(0).to(self.device)

            # Masked MSE per eye; assumes predicted_gaze / target_gaze broadcast against a (1, h, w) mask
            left_gaze_loss = F.mse_loss(predicted_gaze * left_mask, target_gaze * left_mask)
            right_gaze_loss = F.mse_loss(predicted_gaze * right_mask, target_gaze * right_mask)
            loss += left_gaze_loss + right_gaze_loss

        return loss / len(eye_landmarks)
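
A possible usage sketch for the class above. All shapes here are assumptions, since nothing in this thread confirms what predicted_gaze and target_gaze look like in ROME; with random inputs FaceMesh finds no face, so the loss simply falls back to 0.0:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gaze_loss = GazeLoss(device)

# Hypothetical shapes: CHW RGB image in [0, 1], spatial gaze maps that broadcast with (1, H, W) masks
face_image = torch.rand(3, 256, 256)
predicted_gaze = torch.rand(1, 256, 256, device=device)
target_gaze = torch.rand(1, 256, 256, device=device)

loss = gaze_loss.forward(predicted_gaze, target_gaze, face_image)
print(loss)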