Zejun-Yang / AniPortrait

AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animation
Apache License 2.0

Audio2Mesh Model training #131

Closed: agokrani closed this issue 4 months ago

agokrani commented 4 months ago

Hey,

First of all, impressive work. I am glad that so many people are contributing to open-source projects like these.

I have a question about the Audio2Mesh model. I see that the audio_models folder contains the model architecture, but I am still not sure how it was trained. What labels were used in training, and how was MediaPipe used to extract the mesh and align it with the respective audio frames?

0913ktg commented 4 months ago

I have the same question. There are no clear instructions on the training procedure, so I'm struggling with training. I tried to train with my own code, but the loss converges to a solution where the mouth in the target image does not move.

agokrani commented 4 months ago

@0913ktg Can you share the training code you are using? Maybe together we can figure out the problem. Also, what input are you using, and how did you extract the face mesh?

0913ktg commented 4 months ago

Hello @agokrani.

Currently, I am unable to share the training results as they are not yielding expected outcomes. Once I confirm that the results are satisfactory, I will be sure to share them.

I am currently using MediaPipe to extract 3D landmarks. Here's a brief overview of the process:

  1. Audio is extracted from each video clip as a .wav file, and the 3D landmarks for the frames aligned with that audio are prepared.

  2. For now, the 3D landmarks of the first frame of each clip serve as the reference, and the 3D landmarks of all frames are used as training labels.

  3. The output of the audio2mesh model is summed element-wise with the reference 3D landmarks and compared with the labels using an L1 loss.

  4. Next, I plan to experiment with splitting the audio and frames into segments of a specific length for training.
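The scheme in steps 2 and 3 can be written down as a small loss function. This is a minimal NumPy sketch under stated assumptions, not the poster's actual code: the function name `l1_step_loss` is hypothetical, landmarks are assumed to be flattened to shape `(num_frames, num_vertices * 3)`, and the model output is a stand-in array.

```python
import numpy as np


def l1_step_loss(pred_offsets: np.ndarray, landmarks: np.ndarray) -> float:
    """L1 loss for the scheme above (hypothetical helper).

    pred_offsets: model output, shape (T, V*3)
    landmarks:    per-frame 3D landmarks of the clip, shape (T, V*3)
    """
    reference = landmarks[0:1]             # frame 0 as reference, shape (1, V*3)
    pred_abs = pred_offsets + reference    # element-wise sum, broadcast over frames
    return float(np.mean(np.abs(pred_abs - landmarks)))


# Toy clip: 4 frames, 2 vertices (6 coordinates) of random "landmarks".
rng = np.random.default_rng(0)
lmks = rng.normal(size=(4, 6))

perfect = lmks - lmks[0:1]                 # offsets that reproduce every frame
static = np.zeros_like(lmks)               # a model that never moves the face
print(l1_step_loss(perfect, lmks))         # ~0, up to floating-point rounding
print(l1_step_loss(static, lmks))          # mean |frame - first frame|
```

Under this scheme, a model that always outputs zero offsets reproduces the first frame for the whole clip, and its loss is simply the mean absolute difference between each frame and the first one. That value can serve as a baseline when checking whether training is getting past a static face.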

0913ktg commented 4 months ago

If my approach is wrong or you have ideas, please comment.

FacePoluke commented 4 months ago

We provide a script to extract mesh labels: please refer to preprocess_dataset.py and modify it to handle your own data. For training, we use vertex offsets as the target. Manually pick a frame with no expression as the neutral face and subtract it from the vertices of every frame to obtain the offsets. We align the audio with the labels by converting the audio into the same number of frames as the labels, which is done by audio_encoder https://github.com/Zejun-Yang/AniPortrait/blob/8b435af2aeb20c4aef1d2ab746bff64bfe70fe18/src/audio_models/model.py#L59
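The alignment idea can be illustrated with a small sketch. It assumes a 16 kHz waveform, 30 fps labels (the rate used in the dataset script later in this thread), and linear interpolation along time, which is one common way to map audio-encoder frames onto mesh frames. It is not a copy of the linked `audio_encoder`; the function names and the 149x768 stand-in feature array are hypothetical.

```python
import math

import numpy as np


def num_label_frames(num_samples: int, sampling_rate: int, fps: int = 30) -> int:
    """Number of mesh frames covered by an audio clip (same rule as the dataset's length check)."""
    return math.ceil(num_samples / sampling_rate * fps)


def resample_to_frames(features: np.ndarray, seq_len: int) -> np.ndarray:
    """Linearly interpolate audio features (T_audio, C) onto seq_len mesh frames."""
    t_audio, channels = features.shape
    src = np.linspace(0.0, 1.0, t_audio)   # normalized time of each audio frame
    dst = np.linspace(0.0, 1.0, seq_len)   # normalized time of each mesh frame
    return np.stack([np.interp(dst, src, features[:, c]) for c in range(channels)], axis=1)


sr, fps = 16000, 30
wave_len = 3 * sr                                  # a 3-second clip
seq_len = num_label_frames(wave_len, sr, fps)      # 90 mesh frames
audio_feats = np.random.default_rng(0).normal(size=(149, 768))  # stand-in encoder output
aligned = resample_to_frames(audio_feats, seq_len)
print(aligned.shape)  # (90, 768)
```

After this step every mesh frame has exactly one audio feature vector, so the per-frame offsets can be regressed directly.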

0913ktg commented 4 months ago

Hello @FacePoluke ,

In your paper, you mention that an internal dataset of one hour was used to train the audio2mesh model. What segment size did you use when sampling from the full videos? If it was measured in frames, how many frames?

I would also appreciate it if you could explain whether the segments were randomly sampled or selected in some other way.

Thank you.

FacePoluke commented 4 months ago

Our data includes about 600 clips of 3-6 seconds each. Here is our training dataset script:

```python
import os
import random
import math
import glob

from loguru import logger
from tqdm import tqdm
import librosa
import numpy as np
from transformers import Wav2Vec2FeatureExtractor
from torch.utils.data import Dataset


def load_ctr(mesh_dir):
    """Load per-frame 3D landmarks (*_lmks3d.npy) and flatten each frame to a vector."""
    mesh_list = sorted(glob.glob(os.path.join(mesh_dir, "*_lmks3d.npy")))
    verts_arr = np.array([np.load(mesh_file) for mesh_file in mesh_list])
    # (num_frames, num_vertices, 3) -> (num_frames, num_vertices * 3)
    return verts_arr.reshape(verts_arr.shape[0], -1)


class DataProcessor:
    def __init__(self, sampling_rate, wav2vec_model_path):
        self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
        self._sampling_rate = sampling_rate

    def extract_feature(self, audio_path):
        # Load (and resample) the waveform, then normalize it with the wav2vec2 feature extractor.
        speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
        input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
        return input_value, sampling_rate

    def extract_feature_and_label(self, audio_path, label_path):
        input_value, sampling_rate = self.extract_feature(audio_path)
        label = load_ctr(label_path)
        return input_value, label, sampling_rate


class EmoClipDataset(Dataset):
    def __init__(self, config):
        data_dir = config['data_dir']
        sampling_rate = config['sampling_rate']
        wav2vec_model_path = config['wav2vec_model_path']
        neutral_face = config['neutral_face']

        self.fps = 30
        self._data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)
        self._sampling_rate = sampling_rate
        self._features = []
        self._labels = []

        # Vertices of a manually chosen expressionless frame, flattened like the labels.
        self.neutral_verts = np.load(neutral_face).reshape(-1)

        wav_files = sorted(glob.glob(os.path.join(data_dir, "wavs", "*.wav")))
        for wav_path in tqdm(wav_files):
            filename = os.path.splitext(os.path.basename(wav_path))[0]
            # Dataset-specific naming rule that maps an audio file to its landmark folder.
            ctr_path = os.path.join(data_dir, "mp_info", filename[:-4] + "m" + filename[-3:])
            input_value, label, *_ = self._data_preprocessor.extract_feature_and_label(wav_path, ctr_path)
            # Train on offsets from the neutral face rather than absolute vertices.
            label -= self.neutral_verts

            # The audio length expressed in label frames must match the landmark count.
            audio_len = math.ceil(len(input_value) / self._sampling_rate * self.fps)
            label_len = label.shape[0]
            if abs(audio_len - label_len) >= 3:
                raise ValueError(f"Audio/label length mismatch for {wav_path}: {audio_len} vs {label_len} frames")

            self._features.append(input_value)
            self._labels.append(label)

        logger.info('Dataset length: {}'.format(len(self._labels)))

    def __len__(self):
        return len(self._features)

    def __getitem__(self, index):
        rfeature = self._features[index]
        rlabel = self._labels[index]

        # Randomly return either the complete clip or a random sub-segment of it.
        if random.choice(["complete", "clip"]) == "complete":
            return {
                'audio_feature': rfeature,
                'ctrs': rlabel,
                'label_len': len(rlabel)
            }

        # Segment length in label frames: at least 10, or the whole clip if it is shorter.
        segment_length = random.randint(min(10, len(rlabel)), len(rlabel))
        start_index = random.randint(0, len(rlabel) - segment_length)
        end_index = start_index + segment_length
        # Convert label frame indices to waveform sample indices.
        feature_start_index = int(start_index / self.fps * self._sampling_rate)
        feature_end_index = int(end_index / self.fps * self._sampling_rate)
        rlabel_segment = rlabel[start_index:end_index]
        rfeature_segment = rfeature[feature_start_index:feature_end_index]

        audio_len = math.ceil(len(rfeature_segment) / self._sampling_rate * self.fps)
        assert abs(audio_len - segment_length) <= 3, f"{audio_len}, {segment_length}"

        return {
            'audio_feature': rfeature_segment,
            'ctrs': rlabel_segment,
            'label_len': len(rlabel_segment)
        }
```

0913ktg commented 4 months ago

@FacePoluke This is exactly the code I was looking for. Thanks a lot.

agokrani commented 4 months ago

@FacePoluke, thank you so much for all the help here. I am just getting started with video models, so please forgive my ignorance, but just to confirm, you are suggesting that I should:

  1. Use the preprocess_dataset.py script to extract mesh labels.
  2. Once the labels are extracted, write a dataset class similar to yours for my own collected data and use it to train the Audio2Mesh model.

Is my understanding correct? One more thing: the dataset class also converts the audio into the number of frames corresponding to the labels, right? Lastly, why does your dataset randomly choose between "clip" and "complete"? Is there a specific reason for this?
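For reference, the dataset script above does not itself turn audio into per-frame features. It only relates the two streams through time: label frame `i` corresponds to `i / fps` seconds, i.e. waveform sample `int(i / fps * sampling_rate)`, and per FacePoluke's earlier comment the model's audio encoder then produces features matching the label frame count. The standalone function below restates the "clip" branch of `__getitem__` under that reading; `sample_clip` is a hypothetical name, and the 16 kHz default is an assumption.

```python
import random


def sample_clip(num_frames, fps=30, sampling_rate=16000, min_len=10, rng=random):
    """Pick a random window of label frames and the matching waveform sample range.

    Restates the "clip" branch of EmoClipDataset.__getitem__ (hypothetical helper).
    """
    seg_len = rng.randint(min(min_len, num_frames), num_frames)
    start = rng.randint(0, num_frames - seg_len)
    end = start + seg_len
    # frame index / fps -> seconds; seconds * sampling_rate -> sample index
    sample_start = int(start / fps * sampling_rate)
    sample_end = int(end / fps * sampling_rate)
    return (start, end), (sample_start, sample_end)


# Worked example of the mapping: frame 45 at 30 fps is 1.5 s, i.e. sample 24000 at 16 kHz.
print(int(45 / 30 * 16000))  # 24000

frames, samples = sample_clip(90, rng=random.Random(0))
print(frames, samples)
```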

TaekyungKi commented 4 months ago

Hi @FacePoluke, thanks for the code you provided. I have more questions about your Audio2Mesh training.

According to your paper and code, you trained the model to predict mesh offsets from audio by subtracting a neutral face mesh. The face meshes come from a single speaker with a frontal camera view, and the neutral face is the same throughout the whole dataset.

  1. Why didn't you train it on an in-the-wild multi-speaker dataset, such as VoxCeleb2 or CelebV-HQ? I think that could improve generalization to unseen audio.

  2. Why did you train it with neutral face subtraction? I think using a randomly chosen mesh instead (i.e., subtracting a random frame's mesh during batch construction) could improve robustness to the reference image, making it possible to generate from a non-neutral reference image.
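The alternative in point 2 can be sketched as follows. This only illustrates the proposal and is not something the authors describe: `offsets_from_random_reference` is a hypothetical helper, and the labels are assumed to be absolute vertices of shape `(num_frames, num_vertices * 3)` for one clip.

```python
import numpy as np


def offsets_from_random_reference(labels: np.ndarray, rng: np.random.Generator):
    """Offsets relative to a randomly drawn frame of the same clip (hypothetical helper).

    labels: absolute vertices for one clip, shape (T, V*3).
    Returns the drawn frame index and the offsets, shape (T, V*3).
    """
    ref_idx = int(rng.integers(len(labels)))   # draw a reference frame per sample
    return ref_idx, labels - labels[ref_idx]


rng = np.random.default_rng(0)
clip = rng.normal(size=(5, 6))                 # toy clip: 5 frames, 2 vertices
ref_idx, offsets = offsets_from_random_reference(clip, rng)
print(ref_idx, offsets.shape)                  # offsets[ref_idx] is all zeros
```

One consequence worth noting: with a random reference, the target offsets depend on which frame was drawn, so the model would presumably also need that reference mesh as an input. With a single fixed neutral face, the reference is the same for every sample and does not have to be given to the model.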

Thank you in advance. If I missed something, please let me know!

junwenxiong commented 4 months ago

@TaekyungKi I'm curious: what is neutral face subtraction? Could you please explain it? Thanks!

TaekyungKi commented 4 months ago

@junwenxiong In the dataloader class EmoClipDataset provided above:

  1. It loads the neutral face mesh vertices: self.neutral_verts = np.load(neutral_face).reshape(-1)
  2. It subtracts the neutral face from every label: label -= self.neutral_verts. This means the authors use relative mesh offsets as ground-truth labels rather than absolute meshes.

The model is therefore trained to predict vertex offsets, and the absolute mesh is recovered by adding the neutral mesh self.neutral_verts back. I think offset prediction is more temporally consistent than absolute prediction (i.e., without step 2). My question is why the authors used only the neutral face vertices; I think a random mesh could also be used for offset prediction.
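Both directions of this offset scheme fit in a few lines. A minimal NumPy sketch with hypothetical helper names: the target is computed exactly as in the dataset script (label -= self.neutral_verts), while adding offsets back onto a base mesh is an assumption about how predictions are used, since the inference code is not shown in this thread.

```python
import numpy as np


def to_offsets(verts: np.ndarray, neutral: np.ndarray) -> np.ndarray:
    """Training target: per-frame offsets from the neutral face (label -= neutral_verts)."""
    return verts - neutral


def to_absolute(offsets: np.ndarray, base: np.ndarray) -> np.ndarray:
    """Recover absolute vertices by adding offsets onto a base mesh."""
    return offsets + base


rng = np.random.default_rng(0)
neutral = rng.normal(size=6)                       # flattened neutral face: 2 vertices x 3 coords
verts = neutral + 0.1 * rng.normal(size=(4, 6))    # 4 frames near the neutral face
target = to_offsets(verts, neutral)
print(np.allclose(to_absolute(target, neutral), verts))  # True
```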

0913ktg commented 4 months ago

Hi @TaekyungKi. I am currently struggling to train the audio2mesh model on Korean data. Have you by any chance succeeded in training on a custom dataset?