ductuantruong / enskd

Official implementation of the ICASSP 2024 paper: Emphasized Non-Target Speaker Knowledge in Knowledge Distillation for Speaker Verification
https://arxiv.org/abs/2309.14838
MIT License

Questions about training config options (speed_perturb: False) #1

Closed: greatnoble closed this issue 6 months ago

greatnoble commented 6 months ago

@ductuantruong

Thank you for sharing your excellent research.

I would like to inquire about training config options.

It is known that applying speed perturbation in speaker verification helps improve model generalization.

However, in the training config you shared, the speed_perturb option is set to False.

When I change the speed_perturb option to True and train the model, the error shown below occurs.

Shouldn't the model still be able to train with speed_perturb set to True?

```
/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [24,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [25,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [26,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [28,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [30,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Traceback (most recent call last):
  File "/home/azureuser/data2/enskd/wespeaker/bin/train.py", line 254, in <module>
    fire.Fire(train)
  File "/home/azureuser/miniconda3/envs/wespeaker/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/azureuser/miniconda3/envs/wespeaker/lib/python3.9/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/azureuser/miniconda3/envs/wespeaker/lib/python3.9/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/azureuser/data2/enskd/wespeaker/bin/train.py", line 224, in train
    run_epoch(train_dataloader,
  File "/home/azureuser/data2/enskd/wespeaker/utils/executor.py", line 73, in run_epoch
    aux_loss = aux_criterion(outputs, teacher_logits, targets, epoch)
  File "/home/azureuser/miniconda3/envs/wespeaker/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/data2/enskd/wespeaker/models/dkd_loss.py", line 70, in forward
    loss_dkd = min(epoch / self.warmup, 1.0) * dkd_loss(
  File "/home/azureuser/data2/enskd/wespeaker/models/dkd_loss.py", line 12, in dkd_loss
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
  File "/home/azureuser/data2/enskd/wespeaker/models/dkd_loss.py", line 46, in cat_mask
    t1 = (t * mask1).sum(dim=1, keepdims=True)
RuntimeError: The size of tensor a (5994) must match the size of tensor b (17982) at non-singleton dimension 1
```

ductuantruong commented 6 months ago

Hi @greatnoble,

Thank you for your interest in our work. We didn't use speed perturbation during training because the WeSpeaker framework treats audio augmented by speed perturbation as belonging to new speakers, due to the pitch shift the augmentation introduces (https://arxiv.org/pdf/2210.17016.pdf). This makes the total number of speaker classes larger than in the original training set: three times the original 5994 speakers (3 × 5994 = 17982), which is exactly the mismatch in your error log. However, the proposed method is label-level knowledge distillation, so the student and teacher models must output the same number of speaker classes, and the WavLM-ECAPA teacher has 5994 (the number of speakers in the VoxCeleb2 training set). We therefore disabled speed perturbation to keep the number of speaker classes identical between student and teacher.
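To make the mismatch concrete, here is a minimal, self-contained sketch (not the repo's exact code; the tensor names just mirror `cat_mask()` in `models/dkd_loss.py`) that reproduces the final RuntimeError:

```python
import torch
import torch.nn.functional as F

# With speed_perturb: True, WeSpeaker relabels perturbed copies as new
# speakers, so the student head grows to 3 * 5994 = 17982 classes, while
# the frozen WavLM-ECAPA teacher still outputs 5994.
num_teacher_classes = 5994
num_student_classes = 3 * num_teacher_classes  # 17982

batch = 4
teacher_logits = torch.randn(batch, num_teacher_classes)
targets = torch.randint(0, num_student_classes, (batch,))  # labels up to 17981

# gt_mask marks each utterance's target class, as cat_mask() expects.
gt_mask = torch.zeros(batch, num_student_classes)
gt_mask.scatter_(1, targets.unsqueeze(1), 1.0)

pred_teacher = F.softmax(teacher_logits, dim=1)
# Fails like models/dkd_loss.py line 46: (4, 5994) * (4, 17982)
t1 = (pred_teacher * gt_mask).sum(dim=1, keepdim=True)
# RuntimeError: The size of tensor a (5994) must match the size of
# tensor b (17982) at non-singleton dimension 1
```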

If you want to use speed perturbation without changing the number of speaker classes, you can simply remove this line of code. However, I believe that if you apply speed perturbation to an utterance, the speaker characteristics of the original and augmented utterances will differ.
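For illustration, a minimal sketch of that alternative (the function name and speed factors are hypothetical, and WeSpeaker's dataset pipeline is organized differently, but the idea is the same: perturb the audio while keeping the original speaker id, so the class count stays at 5994):

```python
import random
import torch
import torchaudio

def speed_perturb_keep_label(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Speed-perturb a waveform WITHOUT remapping the speaker label.

    Sketch only: the perturbed utterance keeps its original speaker id,
    so the number of output classes does not change.
    """
    speed = random.choice([0.9, 1.0, 1.1])
    if speed == 1.0:
        return waveform
    # Resample back to the original rate after changing the speed.
    perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate,
        [["speed", str(speed)], ["rate", str(sample_rate)]],
    )
    return perturbed
```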

If my reply answers your question, please close this issue. And if you find our code helpful, a star on this repo is much appreciated.

greatnoble commented 6 months ago

@ductuantruong

Thank you for the detailed explanation.

It was a great help in understanding your paper.

Thank you for sharing your very impressive speaker recognition research results.