facebookresearch / VisualVoice

Audio-Visual Speech Separation with Cross-Modal Consistency

About #20

Open dengyuanjie opened 2 years ago

dengyuanjie commented 2 years ago

Thank you very much for your excellent work. I am confused about the definitions of the cross-modal loss and the co-separation loss. In train.py, why are a random number and opt.gt_percentage used to choose between the predicted audio features (e.g. audio_embedding_A1_pred) and the ground-truth ones (e.g. audio_embedding_A1_gt)? According to the method in the paper, shouldn't the predicted features be used?

def get_coseparation_loss(output, opt, loss_triplet):
    if random.random() > opt.gt_percentage:
        audio_embeddings_A1 = output['audio_embedding_A1_pred']
        audio_embeddings_A2 = output['audio_embedding_A2_pred']
        audio_embeddings_B1 = output['audio_embedding_B1_pred']
        audio_embeddings_B2 = output['audio_embedding_B2_pred']
    else:
        audio_embeddings_A1 = output['audio_embedding_A1_gt']
        audio_embeddings_A2 = output['audio_embedding_A2_gt']
        audio_embeddings_B1 = output['audio_embedding_B_gt']
        audio_embeddings_B2 = output['audio_embedding_B_gt']

    coseparation_loss = (
        loss_triplet(audio_embeddings_A1, audio_embeddings_A2, audio_embeddings_B1)
        + loss_triplet(audio_embeddings_A1, audio_embeddings_A2, audio_embeddings_B2)
    )
    return coseparation_loss

def get_crossmodal_loss(output, opt, loss_triplet):
    identity_feature_A = output['identity_feature_A']
    identity_feature_B = output['identity_feature_B']
    if random.random() > opt.gt_percentage:
        audio_embeddings_A1 = output['audio_embedding_A1_pred']
        audio_embeddings_A2 = output['audio_embedding_A2_pred']
        audio_embeddings_B1 = output['audio_embedding_B1_pred']
        audio_embeddings_B2 = output['audio_embedding_B2_pred']
    else:
        audio_embeddings_A1 = output['audio_embedding_A1_gt']
        audio_embeddings_A2 = output['audio_embedding_A2_gt']
        audio_embeddings_B1 = output['audio_embedding_B_gt']
        audio_embeddings_B2 = output['audio_embedding_B_gt']
    crossmodal_loss = (
        loss_triplet(audio_embeddings_A1, identity_feature_A, identity_feature_B)
        + loss_triplet(audio_embeddings_A2, identity_feature_A, identity_feature_B)
        + loss_triplet(audio_embeddings_B1, identity_feature_B, identity_feature_A)
        + loss_triplet(audio_embeddings_B2, identity_feature_B, identity_feature_A)
    )
    return crossmodal_loss

rhgao commented 2 years ago

The separated audio can be of low quality, especially early in training, which makes its embeddings unreliable. The ground-truth audio and its embeddings, however, are always reliable. Using them for a fraction of the iterations (set by opt.gt_percentage) ensures that the cross-modal and co-separation losses learn from meaningful embeddings, which in turn give meaningful distances in the embedding space. The loss computed on the predicted embeddings is the part that actually helps the separation learning.
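
To make the sampling concrete, here is a minimal sketch of the selection step using placeholder values instead of real tensors. The dictionary keys and the `random.random() > gt_percentage` test are taken from the code quoted in the question. The helper names `pick_embedding_source`, `select_audio_embeddings`, and `gt_fraction` are hypothetical and do not appear in the repository.

```python
import random


def pick_embedding_source(gt_percentage, rng):
    """Return "gt" with probability of about gt_percentage, otherwise "pred".

    Mirrors the `random.random() > opt.gt_percentage` test in the quoted code.
    """
    return "pred" if rng.random() > gt_percentage else "gt"


def select_audio_embeddings(output, gt_percentage, rng):
    """Pick the four audio embeddings that would be passed to the triplet losses."""
    if pick_embedding_source(gt_percentage, rng) == "pred":
        keys = [
            "audio_embedding_A1_pred",
            "audio_embedding_A2_pred",
            "audio_embedding_B1_pred",
            "audio_embedding_B2_pred",
        ]
    else:
        # Both B slots share one ground-truth embedding, as in the quoted code.
        keys = [
            "audio_embedding_A1_gt",
            "audio_embedding_A2_gt",
            "audio_embedding_B_gt",
            "audio_embedding_B_gt",
        ]
    return [output[k] for k in keys]


def gt_fraction(gt_percentage, n_steps=10_000, seed=0):
    """Estimate how often the ground-truth branch is taken over n_steps draws."""
    rng = random.Random(seed)
    hits = sum(
        pick_embedding_source(gt_percentage, rng) == "gt" for _ in range(n_steps)
    )
    return hits / n_steps
```

Under these assumptions, `gt_fraction(0.5)` comes out close to 0.5, so roughly half of the iterations anchor the loss on ground-truth embeddings and the rest use the predicted ones. Setting `gt_percentage` to 0 would correspond to always using the predicted embeddings, which is the behavior the question expected.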

dengyuanjie commented 2 years ago

Thanks for your quick reply!