NXTProduct / TUNet


MSM pretraining #5

Closed utahboy3502 closed 2 years ago

utahboy3502 commented 2 years ago

Hello.

I was running your code, but I found that MSM pretraining does not work as described in your paper. If I understand correctly, MSM pretraining takes the narrowband (NB) input, splits it into multiple blocks, and then zero-masks some of the blocks at random, controlled by the mask chunk and mask ratio hyperparameters. However, after checking the TensorBoard logs, it does not seem to work as intended.
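For reference, here is a minimal sketch of how I understood the masking step from the paper (my own illustration, not your actual code; `mask_chunk` and `mask_ratio` just stand in for the hyperparameters mentioned above):

```python
import torch

def random_block_mask(x: torch.Tensor, mask_chunk: int, mask_ratio: float) -> torch.Tensor:
    """Zero out randomly chosen fixed-size blocks along the last (time) dimension."""
    x = x.clone()
    n_blocks = x.shape[-1] // mask_chunk
    n_masked = max(1, int(n_blocks * mask_ratio))
    # choose which blocks to silence
    for i in torch.randperm(n_blocks)[:n_masked].tolist():
        x[..., i * mask_chunk:(i + 1) * mask_chunk] = 0.0
    return x

# e.g. mask roughly 10% of a 2 s, 16 kHz signal in 4096-sample chunks
masked = random_block_mask(torch.randn(1, 32000), mask_chunk=4096, mask_ratio=0.1)
```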

Here is a screenshot of the TensorBoard output during MSM pretraining:

[Screenshot: TensorBoard output during MSM pretraining]

I was wondering if the code is wrong. Can you please check it?

Also, I found that pretraining performs worse than the plain BWE baseline. The only modification I made was changing the batch size from 80 to 16.

[Screenshot: TensorBoard comparison of the two runs]

The upper run is BWE only; the lower run is MSM + BWE.

anhnv125 commented 2 years ago

Sorry for the late response. In the image you attached, it is actually working as intended. I guess you might be wondering about the high-frequency bands of the narrowband input. This MSM is inspired by packet-loss-concealment training. Since we upsample to 16 kHz before applying the zero mask, the masking creates spectral leakage. You could apply the loss mask before upsampling to filter out those frequencies, but that is not our intention. As for the performance degradation, could you provide more details on your experiments, e.g., which dataset you pretrained on and which dataset you trained BWE on?
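To illustrate the ordering (a rough sketch of the idea, not our exact code): masking after upsampling leaves hard zero boundaries in the 16 kHz signal, whose discontinuities leak energy into the high bands, whereas masking before upsampling would let the resampler's low-pass filtering remove that leakage.

```python
import torch
import torchaudio

nb = torch.randn(1, 16000)                        # 1 s of 8 kHz narrowband audio
upsample = torchaudio.transforms.Resample(orig_freq=8000, new_freq=16000)

# ordering used in training, as described above: upsample first, then zero-mask
wb = upsample(nb)
wb_masked = wb.clone()
wb_masked[..., 8000:12000] = 0.0                  # hard zeros -> spectral leakage above 4 kHz

# alternative mentioned above: mask before upsampling, so the leakage above
# the narrowband cutoff is filtered out by the resampler
nb_masked = nb.clone()
nb_masked[..., 4000:6000] = 0.0
wb_alt = upsample(nb_masked)
```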

utahboy3502 commented 2 years ago

Thanks for the reply.

OK. I now see how MSM pretraining works.

Regarding the performance degradation, I used the same CONFIG, and the dataset was VCTK.

To verify the result, I re-trained with a different number of epochs, following your earlier comment on how many epochs to use: 50 epochs of MSM pretraining and 150 epochs of BWE.
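For clarity, this is roughly the schedule I ended up with (the field names below are my own shorthand, not the actual keys in the repo's CONFIG):

```python
# hypothetical summary of my training schedule, not default values from the repo
schedule = {
    "msm_pretrain_epochs": 50,   # stage 1: MSM pretraining
    "bwe_finetune_epochs": 150,  # stage 2: BWE training
    "batch_size": 16,            # changed from the original 80
}
```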

I was able to get this result.

[Screenshot: evaluation results after re-training]

There is a performance improvement in LSD (1.32 -> 1.29).
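For anyone reading later, LSD here is the log-spectral distance. A rough sketch of how it is commonly computed is below; the exact STFT settings and log convention used in the repo's evaluation may differ.

```python
import numpy as np
import librosa

def lsd(ref: np.ndarray, est: np.ndarray, n_fft: int = 2048, hop_length: int = 512, eps: float = 1e-8) -> float:
    """Log-spectral distance: RMS over frequency bins of the log power-spectrum
    difference, averaged over frames."""
    p_ref = np.abs(librosa.stft(ref, n_fft=n_fft, hop_length=hop_length)) ** 2
    p_est = np.abs(librosa.stft(est, n_fft=n_fft, hop_length=hop_length)) ** 2
    diff = np.log10(p_ref + eps) - np.log10(p_est + eps)
    return float(np.mean(np.sqrt(np.mean(diff ** 2, axis=0))))
```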

It works!

anhnv125 commented 2 years ago

Nice, great to hear that. I will close the issue.