BUTSpeechFIT / DiariZen

A toolkit for speaker diarization.
MIT License

Reproducing the exact results in README #1

Closed: Jiltseb closed this issue 2 weeks ago.

Jiltseb commented 1 month ago

@jyhan03 Thanks a lot for the amazing work.

I was trying to reproduce the numbers shown in the README on the public benchmark data, but I could only get, for example, a DER of 17 at collar 0 and 10.8 at collar 0.25 on AMI with wavlm_updated_conformer.

Could you please let me know the parameters you used to get these results, or whether any wav files are missing from the test set? TIA

jyhan03 commented 1 month ago

Hi, thanks for your interest in our work. Could you share the configuration file you used for inference? It seems you are getting the same numbers as WavLM-frozen-Conformer.

Jiltseb commented 1 month ago

For run_stage.sh, I used the config file provided in the wavlm_updated_conformer checkpoint directory: config__2024_07_09--07_53_21.toml.

There was no best directory containing the best model ($diarization_dir/checkpoints/best/pytorch_model.bin) in the checkpoints directory, so I used the latest available checkpoint instead: checkpoints/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin. Is there a way to get the best model? That might be the issue.

jyhan03 commented 1 month ago

Hi, we use model averaging in our paper.

Could you just set diarization_dir to your downloaded checkpoint directory? Our run_stage.sh performs the model averaging automatically. See here.
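
For readers unfamiliar with checkpoint averaging: it takes the element-wise mean of the parameters stored in several checkpoints. The sketch below is a minimal illustration, not DiariZen's implementation. It assumes each checkpoint is a mapping from parameter names to values that support `+` and `/` (PyTorch tensors do, and so do plain floats, which keeps the example runnable without PyTorch); `average_state_dicts` is a hypothetical helper name.

```python
from typing import Mapping, Sequence


def average_state_dicts(state_dicts: Sequence[Mapping]) -> dict:
    """Element-wise mean of several checkpoints' parameters (hypothetical helper)."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint to average")
    names = list(state_dicts[0])
    for sd in state_dicts[1:]:
        if set(sd) != set(names):
            raise ValueError("checkpoints contain different parameter names")
    n = len(state_dicts)
    # sum() starts from 0, and 0 + tensor is a valid tensor op, so the same
    # code works for torch.Tensor values as well as plain floats.
    return {name: sum(sd[name] for sd in state_dicts) / n for name in names}
```

With PyTorch you would load each `epoch_00XX/pytorch_model.bin` with `torch.load(path, map_location="cpu")` and pass the resulting dicts in; that these files are plain state dicts is an assumption here. Integer buffers (if a model has any) become floating point after the division and may need special handling.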

Jiltseb commented 4 weeks ago

Yes, that is exactly what I am doing. Then what should I pass as segmentation_model? The best directory is absent, so I can only point it to a specific checkpoint.

jyhan03 commented 4 weeks ago

Aha, I see. When val_metric_summary.lst is provided, the code ignores the best model, so you don't need to provide one; even if it is provided, it has no effect. What is printed at the beginning of inference? Is there a message about model averaging across 5 checkpoints?
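
A rough sketch of what summary-driven checkpoint selection could look like, inferred only from the log messages in this thread ("averaging previous 5 checkpoints to the converged moment", `val_metric='Loss'`, `val_mode='prev'`): pick the epoch with the lowest validation loss and average it together with the checkpoints immediately before it. The record layout (dicts with `epoch` and `Loss` keys, mirroring the printed log) and the function name `select_prev_checkpoints` are hypothetical; the real format of val_metric_summary.lst and DiariZen's selection code may differ.

```python
from typing import Dict, List


def select_prev_checkpoints(records: List[Dict], num: int = 5,
                            metric: str = "Loss") -> List[Dict]:
    """Best epoch by `metric` (lower is better) plus the num-1 epochs before it.

    Hypothetical helper; `records` look like {"epoch": 22, "Loss": 0.357}.
    """
    if not records:
        raise ValueError("no validation records")
    by_epoch = sorted(records, key=lambda r: r["epoch"])
    # index of the lowest metric value (first one wins on ties)
    best_idx = min(range(len(by_epoch)), key=lambda i: by_epoch[i][metric])
    start = max(0, best_idx - num + 1)
    return by_epoch[start:best_idx + 1]
```

With the Loss values printed in the logs in this thread (epochs 18 to 22, lowest Loss 0.357 at epoch 22), this rule would select epochs 18 to 22, which matches the five checkpoints the logs report averaging.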

Jiltseb commented 4 weeks ago

Yes, it reports model averaging:

averaging previous 5 checkpoints to the converged moment...
Average model over 5 checkpoints...

I don't see any problem with the embedding model either, but the results are still worse:

AMI
    collar = 0:    *** OVERALL ***  DER: 17.30  JER: 23.47
    collar = 0.25: *** OVERALL ***  DER: 10.81  JER: 23.47

AISHELL4
    collar = 0:    *** OVERALL ***  DER: 12.37  JER: 20.33
    collar = 0.25: *** OVERALL ***  DER: 6.39  JER: 20.34

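
To see why the collar = 0.25 numbers are so much lower than the collar = 0 numbers: DER = (missed speech + false alarm + speaker confusion) / total scored reference speech, and a collar removes a window around every reference segment boundary from scoring, which is where many errors occur. The frame-based sketch below is a simplification under stated assumptions: hypothesis labels are assumed to be already mapped to reference labels (real scorers search for the optimal speaker mapping), the collar is applied as ±collar seconds around each reference boundary (scorers differ in this convention, so check the one you use), and `frame_der` is a hypothetical name.

```python
from typing import List, Tuple

Segment = Tuple[float, float, str]  # (start_s, end_s, speaker)


def frame_der(ref: List[Segment], hyp: List[Segment], duration: float,
              collar: float = 0.0, step: float = 0.01) -> float:
    """Frame-based DER with a fixed speaker mapping (hypothetical helper).

    Assumes hypothesis speaker labels are already mapped to reference labels.
    Frames within `collar` seconds of any reference boundary are not scored.
    """
    boundaries = [b for start, end, _ in ref for b in (start, end)]
    missed = false_alarm = confusion = scored_speech = 0.0
    for i in range(int(round(duration / step))):
        t = (i + 0.5) * step  # centre of frame i
        if collar > 0 and any(abs(t - b) < collar for b in boundaries):
            continue  # inside a no-score collar
        ref_spk = {s for start, end, s in ref if start <= t < end}
        hyp_spk = {s for start, end, s in hyp if start <= t < end}
        n_ref, n_hyp = len(ref_spk), len(hyp_spk)
        n_correct = len(ref_spk & hyp_spk)
        scored_speech += n_ref * step
        missed += max(n_ref - n_hyp, 0) * step
        false_alarm += max(n_hyp - n_ref, 0) * step
        confusion += (min(n_ref, n_hyp) - n_correct) * step
    if scored_speech == 0:
        raise ValueError("no scored reference speech")
    return (missed + false_alarm + confusion) / scored_speech
```

For a toy reference `[(0.0, 10.0, "A")]` and hypothesis `[(0.0, 9.0, "A")]` over 10 s, this gives about 0.10 at collar 0 (1 s missed out of 10 s) and about 0.079 at collar 0.25 (0.75 s missed out of 9.5 s scored), illustrating how the collar lowers DER even when the system output is unchanged.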
jyhan03 commented 4 weeks ago

Hi, I just cloned our repo and re-ran inference to check. I get the same numbers as stated in the README. My inference/configuration message is below:

stage2: model inference...
...
Namespace(configuration='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/config__2024_07_09--07_53_21.toml', 
in_wav_scp='/PATH/DiariZen/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp', 
out_dir='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/infer_debug_constrained_AHC_segmentation_step_0.1_min_cluster_size_30_AHC_thres_0.70_pyan_max_length_merged50/metric_Loss_prev/avg_ckpt5/test/AMI', 
uem_file='/PATH/DiariZen/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem', avg_ckpt_num=5, val_metric='Loss', val_mode='prev', 
val_metric_summary='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/val_metric_summary.lst', 
segmentation_model='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/best/pytorch_model.bin', 
embedding_model='/PATH/pretrained/pyannote3/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin', 
max_speakers=8, batch_size=32, max_length_merged='50', merge_closer='0.5', cluster_threshold=0.7, min_cluster_size=30, segmentation_step=0.1)
Average model over 5 checkpoints...
[{'epoch': 18, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0018/pytorch_model.bin'), 'Loss': 0.365, 'DER': 0.359}, 
{'epoch': 19, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0019/pytorch_model.bin'), 'Loss': 0.369, 'DER': 0.62}, 
{'epoch': 20, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0020/pytorch_model.bin'), 'Loss': 0.364, 'DER': 0.744}, 
{'epoch': 21, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0021/pytorch_model.bin'), 'Loss': 0.361, 'DER': 0.478}, 
{'epoch': 22, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin'), 'Loss': 0.357, 'DER': 0.613}]
...

Everything works for me, so I currently have no idea what is going wrong on your side. Note that our AMI test data is the multi-channel version, and we always use the first channel for inference. You can find the data list here.
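
Since only the first channel of the multi-channel AMI recordings is used, here is a minimal way to extract channel 0 from an interleaved multi-channel WAV using Python's standard library. This is a sketch, not the DiariZen data-preparation code; it assumes 16-bit PCM input, and `read_first_channel` / `write_mono` are hypothetical helper names.

```python
import array
import sys
import wave


def read_first_channel(src):
    """Return (samples, sample_rate) for channel 0 of a 16-bit PCM WAV.

    `src` may be a path or a binary file object (hypothetical helper).
    """
    with wave.open(src, "rb") as wf:
        if wf.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        n_channels = wf.getnchannels()
        sample_rate = wf.getframerate()
        raw = wf.readframes(wf.getnframes())
    samples = array.array("h")
    samples.frombytes(raw)
    if sys.byteorder == "big":
        samples.byteswap()  # WAV sample data is little-endian
    # Frames are interleaved: ch0, ch1, ..., ch0, ch1, ...
    return samples[::n_channels], sample_rate


def write_mono(dst, samples, sample_rate):
    """Write 16-bit samples as a single-channel WAV (hypothetical helper)."""
    out = array.array("h", samples)
    if sys.byteorder == "big":
        out.byteswap()
    with wave.open(dst, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(out.tobytes())
```

Dedicated audio tools would do the same job; the point is simply that the scored audio should be a single far-field channel, not a headset mix.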

Jiltseb commented 4 weeks ago

Thanks a lot for verifying the results! I will check them. I was using the Mix-Headset version, downloaded with this pyannote script.

jyhan03 commented 4 weeks ago

You’re welcome. I suspect this is the reason: our model was trained with SDM data only, and all results in our paper are on far-field SDM. Could you try running inference on the SDM data?
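
To rerun on SDM audio, the inference script reads a wav.scp (the `in_wav_scp` argument in the logs here). A small sketch for building one, assuming the Kaldi convention of one `<recording-id> <path>` per line, AMI array files named `<meeting>.Array1-01.wav` (Array1-01 is the channel commonly used as AMI SDM), and bare meeting IDs that match the reference RTTM/UEM; all of these, and the name `build_wav_scp`, are assumptions. Comparing the result against the authors' published data list is a quick way to confirm the same recordings are being scored.

```python
from pathlib import Path
from typing import Iterable, List

# Assumed naming: <meeting>.Array1-01.wav, i.e. first mic of array 1.
SDM_SUFFIX = ".Array1-01.wav"


def build_wav_scp(audio_root, meeting_ids: Iterable[str], out_path) -> List[str]:
    """Write a Kaldi-style wav.scp ("<recording-id> <path>" per line).

    Hypothetical helper. Recording IDs are the bare meeting IDs, which must
    match the IDs in the reference RTTM/UEM files (an assumption here).
    Returns the meeting IDs for which no SDM file was found.
    """
    root = Path(audio_root)
    lines, missing = [], []
    for meeting in sorted(meeting_ids):
        matches = sorted(root.rglob(meeting + SDM_SUFFIX))
        if matches:
            lines.append(f"{meeting} {matches[0].resolve()}")
        else:
            missing.append(meeting)
    Path(out_path).write_text("".join(line + "\n" for line in lines))
    return missing
```

A non-empty return value means some test meetings would be silently absent from scoring, which is one way to end up with numbers that differ from the README.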

Jiltseb commented 4 weeks ago

Sorry to bother you again! This time I used the SDM version. I noticed some differences in your configuration, such as min_cluster_size=30, which is absent from run_stage.sh. I reran with the same settings as your configuration; here are the AMI results at collar 0:
*** OVERALL *** DER: 17.69 JER: 22.34
I'm not sure the audio files are the same, since the ones I downloaded have the suffix Array1-01.wav.

My inference/configuration message is below:

stage2: model inference...
Namespace(configuration='/PATH/diarization/checkpoints/wavlm_updated_conformer//config__2024_07_09--07_53_21.toml',
in_wav_scp='/PATH/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp',
out_dir='/PATH/diarization/checkpoints/wavlm_updated_conformer//infer_segmentation_step_0.1_min_cluster_size_30_AHC_thres_0.70_pyan_max_length_merged50/metric_Loss_prev/avg_ckpt5/test/AMI',
uem_file='/PATH/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem', avg_ckpt_num=5, val_metric='Loss', val_mode='prev',
val_metric_summary='/PATH/checkpoints/wavlm_updated_conformer//val_metric_summary.lst',
segmentation_model=None,
embedding_model='/PATH/wespeaker-embeddings/wespeaker-voxceleb-resnet34-LM/pyannote_pytorch_model.bin',
max_speakers=8, batch_size=32, max_length_merged='50', merge_closer='0.5', cluster_threshold=0.7, min_cluster_size=30, segmentation_step=0.1)
averaging previous 5 checkpoints to the converged moment...
Average model over 5 checkpoints...
[{'epoch': 18, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0018/pytorch_model.bin'), 'Loss': 0.365, 'DER': 0.359},
{'epoch': 19, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0019/pytorch_model.bin'), 'Loss': 0.369, 'DER': 0.62},
{'epoch': 20, 'bin_path': PosixPath('/PATHs/wavlm_updated_conformer/checkpoints/epoch_0020/pytorch_model.bin'), 'Loss': 0.364, 'DER': 0.744},
{'epoch': 21, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0021/pytorch_model.bin'), 'Loss': 0.361, 'DER': 0.478},
{'epoch': 22, 'bin_path': PosixPath('/PATH/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin'), 'Loss': 0.357, 'DER': 0.613}]

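
A quick way to spot configuration mismatches such as min_cluster_size or segmentation_model=None between two runs is to diff the printed argparse `Namespace(...)` lines. A minimal sketch: `parse_namespace` and `diff_configs` are hypothetical helpers, and the regex assumes values are plain literals (quoted strings, numbers, None) as in the logs here; nested values such as lists would need a real parser. If you have the `args` object itself, `vars(args)` gives the dict directly.

```python
import ast
import re

# key=value pairs as printed by argparse: a quoted string or a bare literal
_PAIR = re.compile(r"(\w+)=('(?:[^'\\]|\\.)*'|[^,()\s]+)")


def parse_namespace(line: str) -> dict:
    """Parse a printed argparse Namespace(...) line into a dict (sketch)."""
    parsed = {}
    for key, raw in _PAIR.findall(line):
        try:
            parsed[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            parsed[key] = raw  # keep non-literal values as raw text
    return parsed


def diff_configs(a: dict, b: dict, ignore=()) -> dict:
    """Return {key: (value_in_a, value_in_b)} for keys whose values differ.

    A key missing from one side is reported as None on that side.
    """
    return {
        k: (a.get(k), b.get(k))
        for k in sorted(set(a) | set(b))
        if k not in ignore and a.get(k) != b.get(k)
    }
```

Applied to the two Namespace lines in this thread, the non-path settings (avg_ckpt_num, cluster_threshold, min_cluster_size, segmentation_step, and so on) match; only path-valued arguments differ, plus segmentation_model=None versus a best/pytorch_model.bin path, which the authors say is ignored when val_metric_summary.lst is given. Passing the path keys via `ignore` leaves only the settings worth checking.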
jyhan03 commented 4 weeks ago

Hi, no worries. I downloaded the SDM data you shared and re-ran inference. The AMI results at collar 0 are: *** OVERALL *** DER: 15.36 JER: 22.37.

So something must be wrong. Let's sort it out offline: could you email me your AMI-SDM scoring file (ihan@fit.vut.cz)?

Jiltseb commented 2 weeks ago

A fresh installation solved the issue.
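
The thread does not say what was broken in the original installation. If you hit a similar discrepancy, one low-effort check before reinstalling is to compare installed package versions between a working and a non-working environment. A minimal sketch using the standard library's `importlib.metadata`; `package_versions` is a hypothetical helper, and the package names in the usage line are only examples, not DiariZen's pinned dependency list.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Example names only; run in both environments and compare the output.
    for pkg, ver in package_versions(["torch", "torchaudio", "pyannote.audio"]).items():
        print(f"{pkg}: {ver}")
```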