MRzzm / DINet

The source code of "DINet: deformation inpainting network for realistic face visually dubbing on high resolution video."

Asking for advice: is this Syncnet training code correct? #40

Closed lcc157 closed 1 year ago

lcc157 commented 1 year ago
from models.Syncnet import SyncNetPerception,SyncNet
from config.config import DINetTrainingOptions
from sync_batchnorm import convert_model

from torch.utils.data import DataLoader
from dataset.dataset_DINet_syncnet import DINetDataset

from utils.training_utils import get_scheduler, update_learning_rate,GANLoss

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import os
import torch.nn.functional as F

if __name__ == "__main__":
    # load config
    opt = DINetTrainingOptions().parse_args()
    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    # init network

    net_lipsync = SyncNet(15,29,128).cuda()

    # note: despite the variable name, this is BCE loss, not MSE (see the discussion below)
    criterionMSE = nn.BCELoss().cuda()
    # set scheduler
    # set label of syncnet perception loss
    real_tensor = torch.tensor(1.0).cuda()

    # setup optimizer
    # optimizer_s = optim.Adam(net_lipsync.parameters(), lr=opt.lr_g)
    optimizer_s = optim.Adamax(net_lipsync.parameters(), lr=opt.lr_g)

    # set scheduler
    net_s_scheduler = get_scheduler(optimizer_s, opt.non_decay, opt.decay)

    # load training data
    train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size)
    training_data_loader = DataLoader(dataset=train_data,  batch_size=opt.batch_size, shuffle=True,drop_last=True,num_workers=12)
    train_data_length = len(training_data_loader)

    # load training data
    test_data = DINetDataset(opt.test_data,opt.augment_num,opt.mouth_region_size)
    test_data_loader = DataLoader(dataset=test_data,  batch_size=1, shuffle=True,drop_last=True,num_workers=12)
    test_data_length = len(test_data_loader)

    min_loss = 100
    # start train
    for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1):
        net_lipsync.train()
        for iteration, data in enumerate(training_data_loader):
            # forward
            optimizer_s.zero_grad()
            source_clip, deep_speech_full, y = data
            source_clip = torch.cat(torch.split(source_clip, 1, dim=1), 0).squeeze(1).float().cuda()
            source_clip = torch.cat(torch.split(source_clip, opt.batch_size, dim=0), 1).cuda()
            deep_speech_full = deep_speech_full.float().cuda()

            y = y.cuda()
            ## sync perception loss
            source_clip_mouth = source_clip[:, :, train_data.radius:train_data.radius + train_data.mouth_region_size,
            train_data.radius_1_4:train_data.radius_1_4 + train_data.mouth_region_size]
            sync_score = net_lipsync(source_clip_mouth, deep_speech_full)        

            loss_sync = criterionMSE(sync_score.unsqueeze(1), y)

            loss_sync.backward()
            optimizer_s.step()

            print(
                "===> Epoch[{}]({}/{}):  Loss_Sync: {:.4f} lr_g = {:.7f} ".format(
                    epoch, iteration, len(training_data_loader), float(loss_sync) ,
                    optimizer_s.param_groups[0]['lr']))

        update_learning_rate(net_s_scheduler, optimizer_s)

        # checkpoint
        if epoch %  opt.checkpoint == 0 :
            if not os.path.exists(opt.result_path):
                os.makedirs(opt.result_path)
            model_out_path = os.path.join(opt.result_path, 'netS_model_epoch_{}.pth'.format(epoch))
            states = {
                'epoch': epoch + 1,
                'state_dict': {'net': net_lipsync.state_dict()},
                'optimizer': {'net': optimizer_s.state_dict()}
            }
            torch.save(states, model_out_path)
            print("Checkpoint saved to {}".format(epoch))
        if epoch %  opt.stop_checkpoint == 0:
            break
davidmartinrius commented 1 year ago

@lcc157 asked the same question in another project. Here is a response from @primepake: https://github.com/primepake/wav2lip_288x288/issues/43#issuecomment-1616430268

@primepake proposes to use BCE instead of MSE.

The BCE (Binary Cross Entropy) Loss and MSE (Mean Squared Error) Loss are two commonly used loss functions in machine learning and deep learning problems. Although both are used to measure the discrepancy between predicted outputs and actual outputs, there are key differences between them.

  1. Application: BCE Loss is primarily used in binary classification problems, where the goal is to predict the probability of belonging to a specific class or category (0 or 1). On the other hand, MSE Loss is commonly used in regression problems, where the goal is to predict a continuous numerical value.

  2. Output Range: BCE Loss produces a per-example loss in the range from 0 to ∞; the closer the value is to zero, the better the prediction. MSE Loss measures the average of the squared errors between predictions and actual outputs; it is also unbounded above, with values closer to zero indicating better predictions.

  3. Sensitivity to Outliers: MSE Loss is more sensitive to outliers than BCE Loss. This is because squared errors amplify the contribution of outliers to the loss function. In contrast, BCE Loss treats each individual prediction independently and is not as affected by extreme values.

  4. Gradients: Since BCE Loss is commonly used in binary classification problems, the loss function is derived in a way that gradients are more stable and easier to compute compared to MSE Loss. This can be beneficial during the model training process.

In summary, BCE Loss and MSE Loss are different loss functions used for different types of problems. BCE Loss is suitable for binary classification, while MSE Loss is more appropriate for regression problems. The choice of the loss function depends on the problem type and the data characteristics.
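
A minimal sketch of the two losses side by side in PyTorch (the numbers are made up for illustration):

import torch
import torch.nn as nn

# Hypothetical sync scores for four audio/video pairs; for BCE these must already be
# probabilities in [0, 1], while MSE accepts any real-valued prediction.
pred = torch.tensor([0.9, 0.2, 0.7, 0.4])
target = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 1 = in sync, 0 = out of sync

bce = nn.BCELoss()(pred, target)  # classification-style loss
mse = nn.MSELoss()(pred, target)  # regression-style loss
print(float(bce), float(mse))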

ghost commented 1 year ago

yeah! I mean in this problem BCE should be better than MSE

lcc157 commented 1 year ago

@davidmartinrius @primepake Thanks! So, when I run python train_DINet_clip.py (the clip training stage), should I also use BCE?

codersun123 commented 1 year ago

@lcc157 I don't think train_DINet_clip.py needs to be modified; you can use the trained Syncnet directly. The training code probably has the following problem: 1. The dataset sampling needs to be changed to follow the wav2lip approach.

ghost commented 1 year ago

@davidmartinrius @primepake Thanks! So, when I run python train_DINet_clip.py (the clip training stage), should I also use BCE?

You can run it with BCE loss, the same as in hq_wav2lip_train.py.

Psarpei commented 1 year ago

Why not use cosine similarity instead?

davidmartinrius commented 1 year ago

Well, BCE Loss is suitable for binary classification problems, MSE Loss is appropriate for regression tasks, and Cosine Similarity is useful for measuring similarity between vectors.

But,

  1. What kind of problem is this model trying to solve?
  2. How do we determine which type of loss is more appropriate?
ghost commented 1 year ago

Why not use cosine similarity instead?

Yes, they used cosine similarity, basically:
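
A minimal sketch of that kind of cosine-similarity scoring, assuming face_embedding and audio_embedding are the flattened outputs of the two encoders (an illustration only, not the exact code being referred to):

import torch.nn.functional as F

def sync_score(face_embedding, audio_embedding):
    # L2-normalize both embeddings, then measure their agreement with cosine similarity;
    # the resulting score can be fed to a BCE-style loss against in-sync / out-of-sync labels.
    face_embedding = F.normalize(face_embedding, p=2, dim=1)
    audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
    return F.cosine_similarity(face_embedding, audio_embedding, dim=1)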

codersun123 commented 1 year ago

At the moment I am training syncnet on a Chinese training set, using a consumer-grade graphics card (RTX 3060 Ti). I have now run into a confusing problem: the loss does not decrease significantly and keeps hovering around 0.69. What advice can you give me? @primepake

1059692261 commented 1 year ago

In my situation, the loss (MSE) of syncnet drops dramatically to the 1e-3 level in the first epoch, and I think something must have gone wrong. Any advice would be appreciated. BTW, using BCE reports an error in the loss calculation: Assertion input_val >= zero && input_val <= one failed.
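
That assertion comes from nn.BCELoss, which only accepts predictions in [0, 1]; raw network scores (or a cosine similarity, which can be negative) have to be mapped into that range first, or nn.BCEWithLogitsLoss can be used instead. A small sketch, with score standing in for whatever the model outputs:

import torch
import torch.nn as nn

score = torch.tensor([-0.3, 0.8])   # raw scores, not probabilities
target = torch.tensor([0.0, 1.0])

# Option 1: let the loss apply the sigmoid internally (numerically more stable).
loss_logits = nn.BCEWithLogitsLoss()(score, target)

# Option 2: map a cosine similarity from [-1, 1] into [0, 1] before nn.BCELoss.
prob = (score.clamp(-1.0, 1.0) + 1.0) / 2.0
loss_bce = nn.BCELoss()(prob, target)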

wning13 commented 1 year ago

An article I read about wav2lip mentioned that during syncnet training it is important to have a sufficient number of speakers in the training data; otherwise the loss stays around 0.69. @codersun123

codersun123 commented 1 year ago

An article I read about wav2lip mentioned that during syncnet training it is important to have a sufficient number of speakers in the training data; otherwise the loss stays around 0.69. @codersun123

My current test set has 120 speakers and 43 hours of content, and my feeling is that this is not the reason. I'm going to try wav2lip on this dataset and see how the loss behaves.

davidmartinrius commented 1 year ago

An article I read about wav2lip mentioned that during syncnet training it is important to have a sufficient number of speakers in the training data; otherwise the loss stays around 0.69. @codersun123

My current test set has 120 speakers and 43 hours of content, and my feeling is that this is not the reason. I'm going to try wav2lip on this dataset and see how the loss behaves.

Well, you said you are using an RTX 3060 Ti, so you are probably using a small batch size due to limited VRAM. When the batch size is too small because of GPU memory (VRAM) limitations, it can hurt model convergence. Here are some ways in which a very small batch size can affect convergence:

Noisier gradient estimates: with a small batch size, the gradient computed at each iteration is based on a limited set of training examples. This can make the gradient estimates noisier and less accurate, which can make it difficult for the model to find an optimal direction to update the weights. As a result, the model may take longer to converge or may not converge at all.

Less stable weight updates: With a small batch size, weight updates are performed more frequently. This can make updates more volatile and less stable. Instead of having a more consistent and reliable update direction based on a larger set of examples, the model can be affected by random fluctuations in the individual training examples. This can lead to oscillations in the loss function and make convergence more difficult.

Lower computational efficiency: A small batch size can make training less computationally efficient. This is because the model must perform more iterations to process all the training examples, which can lead to longer training time overall. Also, if the batch size is too small compared to the processing power of the GPU, there may be additional overhead due to frequent data transfers between main memory and VRAM.

You may need a graphics card with more VRAM than 8-12 GB. Have you tried an RTX A6000 or an A100?
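
If switching to a bigger GPU is not an option, one common workaround is gradient accumulation, which trades extra iterations for a larger effective batch. A rough sketch on top of the training loop above (accum_steps and the prepare_batch helper are assumptions, not part of the original code; criterion is whichever loss is chosen):

accum_steps = 4  # effective batch size = opt.batch_size * accum_steps
optimizer_s.zero_grad()
for iteration, data in enumerate(training_data_loader):
    # hypothetical helper doing the cropping/reshaping shown in the training script above
    source_clip_mouth, deep_speech_full, y = prepare_batch(data)
    sync_score = net_lipsync(source_clip_mouth, deep_speech_full)
    loss_sync = criterion(sync_score.unsqueeze(1), y) / accum_steps  # scale so gradients average out
    loss_sync.backward()  # gradients accumulate until we step
    if (iteration + 1) % accum_steps == 0:
        optimizer_s.step()
        optimizer_s.zero_grad()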

lcc157 commented 1 year ago

@codersun123 Hi, which Chinese dataset are you using? Did you make it yourself?

Saksham209 commented 1 year ago

In my situation, the loss (MSE) of syncnet drops dramatically to the 1e-3 level in the first epoch, and I think something must have gone wrong. Any advice would be appreciated. BTW, using BCE reports an error in the loss calculation: Assertion input_val >= zero && input_val <= one failed.

@1059692261 I have encountered a similar issue while training syncnet on the HDTF dataset for finetuning DINet. I am using the implementation provided by @lcc157. Any ideas as to why this is the case?

Saksham209 commented 1 year ago

@primepake @davidmartinrius Could you please share a more detailed implementation of syncnet training for the DINet use case? I am new to this and any help would be really appreciated. Thank you.

codersun123 commented 1 year ago

@codersun123 Hi, which Chinese dataset are you using? Did you make it yourself?

I made one myself; the videos are speech videos from bilibili. I referenced the code you shared and the wav2lip code, and switched to BCE loss. If you have anything to share, feel free to contact me at codesun1234567@gmail.com

lcc157 commented 1 year ago

@codersun123 Hi, which Chinese dataset are you using? Did you make it yourself?

I made one myself; the videos are speech videos from bilibili. I referenced the code you shared and the wav2lip code, and switched to BCE loss. If you have anything to share, feel free to contact me at codesun1234567@gmail.com

Does it work well for Chinese? I am still preparing data. Below is the processing I adapted from wav2lip for BCE. I tried training syncnet with 1 hour of training data, and both MSE and BCE dropped to 1e-3. Is the amount of data too small, or should the data loader not sample the data randomly?

def forward(self, image, audio):
    image_embedding = self.face_encoder(image)
    audio_embedding = self.audio_encoder(audio)
    audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
    face_embedding = image_embedding.view(image_embedding.size(0), -1)

    audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
    face_embedding = F.normalize(face_embedding, p=2, dim=1)
    out_score = nn.functional.cosine_similarity(audio_embedding, face_embedding)
    return out_score
Psarpei commented 1 year ago

An article I read about wav2lip mentioned that during syncnet training it is important to have a sufficient number of speakers in the training data; otherwise the loss stays around 0.69. @codersun123

Can you provide the article?

codersun123 commented 1 year ago

Does it work well for Chinese? I am still preparing data. I tried training syncnet with 1 hour of training data, and both MSE and BCE dropped to 1e-3. Is the amount of data too small, or should the data loader not sample the data randomly?

My loss got stuck at 0.69 and would not go down; I have not solved this problem yet.

Psarpei commented 1 year ago

The problem with the overfitting in the code provided by @lcc157 is that line 35 has real_tensor = torch.tensor(1.0).cuda(). Because of that, syncnet learns to always predict "in sync". The dataset needs to be updated so that it returns 1 (in sync) or 0 (not in sync); that is, with 50% probability it should also create a "not in sync" example where the audio and video are not aligned.
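
In other words, the constant target has to become a per-sample label produced by the dataset, roughly like this (a sketch of the idea only; synced_clip and misaligned_clip are placeholders for how the clips are actually assembled):

import random
import torch

if random.random() < 0.5:
    clip = synced_clip       # audio and video from the same, aligned clip -> positive
    y = torch.ones(1)        # label: in sync
else:
    clip = misaligned_clip   # audio paired with frames from another clip  -> negative
    y = torch.zeros(1)       # label: not in sync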

lcc157 commented 1 year ago

@Psarpei

The problem with the overfitting in the code provided by @lcc157 is that line 35 has real_tensor = torch.tensor(1.0).cuda(). Because of that, syncnet learns to always predict "in sync". The dataset needs to be updated so that it returns 1 (in sync) or 0 (not in sync); that is, with 50% probability it should also create a "not in sync" example where the audio and video are not aligned.

Can you provide more detailed code? Thanks

lcc157 commented 1 year ago

My loss got stuck at 0.69 and would not go down; I have not solved this problem yet.

Mine is also stuck at 0.69. Have you solved it?

wning13 commented 1 year ago

Can you provide the article?

https://zhuanlan.zhihu.com/p/613996840. However, there are no more details in the article. @Psarpei

Saksham209 commented 1 year ago

Can you provide the article?

https://zhuanlan.zhihu.com/p/613996840. However, there are no more details in the article. @Psarpei

Does this also mean that we cannot use a dataset the size of the HDTF dataset used by DINet for training syncnet? How big a dataset should we be considering for training it?

wning13 commented 1 year ago

Can you provide more detailed code? Thanks

You can refer to the Wav2Lip implementation: https://github.com/Rudrabha/Wav2Lip/blob/0ec01ce80a84ede143567adb63ee773adf4e2668/color_syncnet_train.py#L82C12-L82C12

lcc157 commented 1 year ago

You can refer to the Wav2Lip implementation: https://github.com/Rudrabha/Wav2Lip/blob/0ec01ce80a84ede143567adb63ee773adf4e2668/color_syncnet_train.py#L82C12-L82C12

Thanks, I have done that.

wning13 commented 1 year ago

Mine is also stuck at 0.69. Have you solved it?

Have you solved it? Training directly on my own data also got stuck at 0.69, but loading the trained English model first and then continuing training does converge.
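
A minimal sketch of warm-starting from such a pretrained checkpoint before continuing training, assuming it was saved in the format used by the training script above (the path is a placeholder):

checkpoint = torch.load('./checkpoints/netS_model_epoch_XX.pth', map_location='cuda')
net_lipsync.load_state_dict(checkpoint['state_dict']['net'])
# Optionally restore the optimizer as well, to resume training rather than fine-tune:
# optimizer_s.load_state_dict(checkpoint['optimizer']['net'])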

Saksham209 commented 1 year ago

Thanks, I have done that.

Hi, I wanted to know about the experiment details that you did in order to solve the syncnet training issue. Any leads are much appreciated.

lcc157 commented 1 year ago

Added data loading; the loss uses BCE:

import torch
import numpy as np
import json
import random
import cv2

from torch.utils.data import Dataset

def get_data(json_name,augment_num):
    print('start loading data')
    with open(json_name,'r') as f:
        data_dic = json.load(f)
    data_dic_name_list = []
    for augment_index in range(augment_num):
        for video_name in data_dic.keys():
            data_dic_name_list.append(video_name)
    random.shuffle(data_dic_name_list)
    print('finish loading')
    return data_dic_name_list,data_dic

class DINetDataset(Dataset):
    def __init__(self,path_json,augment_num,mouth_region_size):
        super(DINetDataset, self).__init__()
        self.data_dic_name_list,self.data_dic = get_data(path_json,augment_num)
        self.mouth_region_size = mouth_region_size
        self.radius = mouth_region_size//2
        self.radius_1_4 = self.radius//4
        self.img_h = self.radius * 3 + self.radius_1_4
        self.img_w = self.radius * 2 + self.radius_1_4 * 2
        self.length = len(self.data_dic_name_list)

    def __getitem__(self, index):
        video_name = self.data_dic_name_list[index]
        video_clip_num = len(self.data_dic[video_name]['clip_data_list'])    
        try:      
            source_anchor = random.sample(range(video_clip_num), 1)[0]
            wrong_source_anchor = random.sample(range(video_clip_num), 1)[0]
        except:
            print(video_name,video_clip_num)
            video_name = self.data_dic_name_list[0]
            video_clip_num = len(self.data_dic[video_name]['clip_data_list'])
            source_anchor = random.sample(range(video_clip_num), 1)[0]
            wrong_source_anchor = random.sample(range(video_clip_num), 1)[0]
        while source_anchor == wrong_source_anchor:
            wrong_source_anchor  = random.sample(range(video_clip_num), 1)[0]

        source_clip_list = []
        source_clip_mask_list = []
        deep_speech_list = []

        # alternate positive / negative samples: even index -> matching clip (label 1, in sync),
        # odd index -> frames from a different clip of the same video (label 0, out of sync)
        # if random.choice([True, False]):
        if (index & 1) == 0:
            y = torch.ones(1).float()
            chosen = source_anchor
        else:
            y = torch.zeros(1).float()
            chosen = wrong_source_anchor

        source_image_path_list = self.data_dic[video_name]['clip_data_list'][chosen]['frame_path_list']
        for source_frame_index in range(2, 2 + 5):
            ## load source clip
            source_image_data = cv2.imread(source_image_path_list[source_frame_index])[:, :, ::-1]
            source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h)) / 255.0
            source_clip_list.append(source_image_data)

            ## load deep speech feature (always taken from source_anchor, so when `chosen`
            ## is the wrong clip the audio and frames form an out-of-sync pair)
            deepspeech_array = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'][
                                       source_frame_index - 2:source_frame_index + 3])
            deep_speech_list.append(deepspeech_array)

        source_clip = np.stack(source_clip_list, 0)
        deep_speech_clip = np.stack(deep_speech_list, 0)
        #deep_speech_clip = np.reshape(deep_speech_clip,(-1,1024))
        deep_speech_full = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'])

        # # 2 tensor
        source_clip = torch.from_numpy(source_clip).float().permute(0, 3, 1, 2)
        deep_speech_full = torch.from_numpy(deep_speech_full).permute(1, 0)
        deep_speech_clip = torch.from_numpy(deep_speech_clip).permute(2,0, 1)

        return source_clip, deep_speech_full, y

    def __len__(self):
        return self.length
ZYBOBO commented 1 year ago

Thanks for providing the data loading code, but the output of DINet's sync network (syncnet) is five-dimensional, [b, 1, 1, 8, 8], while y in the code above is three-dimensional, [b, 1, 1], so the network output needs to be adjusted.

decajcd commented 1 year ago

Thanks for providing the data loading code, but the output of DINet's sync network (syncnet) is five-dimensional, [b, 1, 1, 8, 8], while y in the code above is three-dimensional, [b, 1, 1], so the network output needs to be adjusted.

How did you resolve the dimension mismatch?

decajcd commented 1 year ago

In my situation, the loss (MSE) of syncnet drops dramatically to the 1e-3 level in the first epoch, and I think something must have gone wrong. Any advice would be appreciated. BTW, using BCE reports an error in the loss calculation: Assertion input_val >= zero && input_val <= one failed.

Have you resolved it?

jinlingxueluo commented 1 year ago

How did you resolve the dimension mismatch?

I applied adaptive average pooling to the syncnet output before computing the loss, but that does not solve the problem of the loss not decreasing.

lcc157 commented 1 year ago

How did you resolve the dimension mismatch?

Add a pooling layer to the discriminator network to unify the dimensions.
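
A hedged sketch of that adjustment, using the [b, 1, 1, 8, 8] output shape mentioned above (whether to apply a sigmoid depends on how the score map is produced):

import torch
import torch.nn as nn
import torch.nn.functional as F

sync_score = torch.rand(4, 1, 1, 8, 8)         # example output shape from the discussion
pooled = F.adaptive_avg_pool3d(sync_score, 1)  # -> [4, 1, 1, 1, 1]
pooled = pooled.view(pooled.size(0), 1)        # -> [4, 1], matching the dataset's y
y = torch.ones(4, 1)
loss = nn.BCEWithLogitsLoss()(pooled, y)       # or nn.BCELoss after squashing to [0, 1]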

lcc157 commented 1 year ago

I applied adaptive average pooling to the syncnet output before computing the loss, but that does not solve the problem of the loss not decreasing.

How many epochs did you train for? Mine only started to converge after roughly 40 epochs, settling at about 0.2-0.3. You could try increasing the training batch size.

jinlingxueluo commented 1 year ago

How many epochs did you train for? Mine only started to converge after roughly 40 epochs, settling at about 0.2-0.3. You could try increasing the training batch size.

The dataset I used is HDTF. The loss dropped to a very low value after only a batch or two (I forget the exact number), but after that it barely moved, and using it as the pretrained sync network for DINet did not give good results.

NaMoCv commented 12 months ago

How many epochs did you train for? Mine only started to converge after roughly 40 epochs, settling at about 0.2-0.3. You could try increasing the training batch size.

Did you set up a validation set? I set one up, and while the training loss is converging, the validation loss keeps increasing, which is really strange.

jinwonkim93 commented 8 months ago

@lcc157 what batch size did you use for training syncnet?

tailangjun commented 6 months ago

My loss got stuck at 0.69 and would not go down; I have not solved this problem yet.

Did you ever manage to solve this problem, buddy?