MRzzm / DINet

The source code of "DINet: deformation inpainting network for realistic face visually dubbing on high resolution video."
895 stars 167 forks source link

有大佬成功训练出效果较好的 syncnet吗? #105

Closed tailangjun closed 2 months ago

tailangjun commented 2 months ago

我尝试过 criterionBCE + hubert来训练 syncnet,augment_num=40,到第 18个 epoch时,loss可以降到 0.12,这低得有点不太正常了,我之前在 Wav2lip中训练 loss差不多在 0.21左右,具体是什么问题,我还没搞明白。 后面我就训练 frame64、frame128、frame256,其实 frame256训练完,使用训练集的语音来推理效果是很好的

https://github.com/MRzzm/DINet/assets/12316965/dc20924f-4f62-408b-9735-8fdb2bf7db54

接着训练 clip256,越训练脸部错位就越严重,怀疑是 syncnet没训练好

https://github.com/MRzzm/DINet/assets/12316965/3b06e810-c51e-457b-a77a-6d00cddc2ea9

不知道是否有大佬遇到过类似的问题,谢谢

tailangjun commented 2 months ago

后面我改用 criterionMSE + wav2vec2来训练 syncnet,发现 loss很很快降到 0.26,然后一直在 0.25左右徘徊,感觉更不正常

xy-gao commented 2 months ago

后面我改用 criterionMSE + wav2vec2来训练 syncnet,发现 loss很很快降到 0.26,然后一直在 0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

If your train batch is balanced, maybe your syncnet keep predicting 0.5 for all input sample, which will result in mse=0.25 or bce=0.6931. You can print your model output during training to check it out.

tailangjun commented 2 months ago

后面我改用标准MSE + wav2vec2来训练syncnet,发现loss很快降到0.26,然后一直在0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

如果你的训练批次是平衡的,也许你的同步网络会继续预测所有输入样本的 0.5,这将导致 mse=0.25 或 bce=0.6931。 您可以在训练期间打印模型输出以进行检查。

好的,非常感谢

Wangman1 commented 2 weeks ago

后面我改用标准MSE + wav2vec2来训练syncnet,发现loss很快降到0.26,然后一直在0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

如果你的训练批次是平衡的,也许你的同步网络会继续预测所有输入样本的 0.5,这将导致 mse=0.25 或 bce=0.6931。 您可以在训练期间打印模型输出以进行检查。

好的,非常感谢

@tailangjun ,大佬您好,请问这里使用 mse=0.25 的问题有解决方法吗,在这里卡了好久了,期望得到您的回复,谢谢~~~

tailangjun commented 2 weeks ago

后面我改用标准MSE + wav2vec2来训练syncnet,发现loss很快降到0.26,然后一直在0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

如果你的训练批次是平衡的,也许你的同步网络会继续预测所有输入样本的 0.5,这将导致 mse=0.25 或 bce=0.6931。 您可以在训练期间打印模型输出以进行检查。

好的,非常感谢

@tailangjun ,大佬您好,请问这里使用 mse=0.25 的问题有解决方法吗,在这里卡了好久了,期望得到您的回复,谢谢~~~

我用 MSE也是是被卡在 0.25,后面就换成 BCELoss了

Wangman1 commented 2 weeks ago

后面我改用标准MSE + wav2vec2来训练syncnet,发现loss很快降到0.26,然后一直在0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

如果你的训练批次是平衡的,也许你的同步网络会继续预测所有输入样本的 0.5,这将导致 mse=0.25 或 bce=0.6931。 您可以在训练期间打印模型输出以进行检查。

好的,非常感谢

@tailangjun ,大佬您好,请问这里使用 mse=0.25 的问题有解决方法吗,在这里卡了好久了,期望得到您的回复,谢谢~~~

我用 MSE也是是被卡在 0.25,后面就换成 BCELoss了

用 bce 又被卡到 0.69,请问有什么解决方法吗

sunjian2015 commented 1 week ago

后面我改用标准MSE + wav2vec2来训练syncnet,发现loss很快降到0.26,然后一直在0.25左右徘徊,感觉更不正常

>>> y_hat = torch.Tensor([0.5,0.5])
>>> y = torch.Tensor([1,0])
>>> nn.MSELoss()(y_hat, y) 
tensor(0.2500)
>>> nn.BCELoss()(y_hat, y)  
tensor(0.6931)

如果你的训练批次是平衡的,也许你的同步网络会继续预测所有输入样本的 0.5,这将导致 mse=0.25 或 bce=0.6931。 您可以在训练期间打印模型输出以进行检查。

好的,非常感谢

@tailangjun ,大佬您好,请问这里使用 mse=0.25 的问题有解决方法吗,在这里卡了好久了,期望得到您的回复,谢谢~~~

我用 MSE也是是被卡在 0.25,后面就换成 BCELoss了

用 bce 又被卡到 0.69,请问有什么解决方法吗

同样的问题,请问你解决了吗?

tailangjun commented 1 week ago

后面我改用 criterionMSE + wav2vec2来训练 syncnet,发现 loss很很快降到 0.26,然后一直在 0.25左右徘徊,感觉更不正常

怀疑数据集没对齐,你可以用 syncnet_python过滤掉那些 offset不在 [-1, 1]区间的数据