Closed: Jiltseb closed this issue 2 weeks ago
Hi, thanks for your interest in our work.
Could you share the configuration file you used for inference? It seems that you achieved the same numbers as the WavLM-frozen-Conformer.
For run_stage.sh, I used the provided config file in the checkpoints directory of wavlm_updated_conformer, with the name config__2024_07_09--07_53_21.toml.
There was no best dir specifying the best model in the checkpoints directory ($diarization_dir/checkpoints/best/pytorch_model.bin); instead, I was using the latest available checkpoint: checkpoints/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin. Is there a way to get access to the best model? That might be the issue.
Hi, we use model averaging in our paper.
Could you just replace the diarization_dir with your downloaded checkpoint dir?
Our run_stage.sh will run the model averaging automatically. See here.
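For intuition, checkpoint averaging just takes the element-wise mean of the parameters of several checkpoints. A minimal sketch, using plain Python dicts as a stand-in for torch state dicts (this is not DiariZen's actual code):

```python
def average_checkpoints(state_dicts):
    """Element-wise mean of several checkpoints.

    state_dicts: list of dicts mapping a parameter name to a list of
    floats (a toy stand-in for real torch state_dicts).
    """
    n = len(state_dicts)
    return {
        key: [sum(vals) / n for vals in zip(*(sd[key] for sd in state_dicts))]
        for key in state_dicts[0]
    }
```

The averaged weights are then used as the segmentation model, which is why no single best checkpoint file is needed.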
Yes, I am doing exactly that. Then what should I provide as the segmentation_model? The best dir is absent, and I can only replace it with a specific model.
Aha, I see. When val_metric_summary.lst is provided, the code will ignore the best model, so you don't need to provide one. Actually, even if provided, it should have no effect. What is your print message at the beginning of the inference? Is there anything about model averaging across 5 checkpoints?
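For context, the "averaging previous 5 checkpoints to the converged moment" message suggests the selection step picks the checkpoint with the best validation metric from val_metric_summary.lst plus the avg_ckpt_num - 1 checkpoints before it. A hypothetical sketch of that logic, not the repo's actual code:

```python
def select_checkpoints(summary, avg_ckpt_num=5, metric="Loss"):
    """Pick the best-metric checkpoint and the ones just before it.

    summary: list of dicts with 'epoch' and metric keys, sorted by
    epoch (like the entries printed during inference). Hypothetical
    sketch of a "previous N checkpoints" selection rule.
    """
    best_idx = min(range(len(summary)), key=lambda i: summary[i][metric])
    start = max(0, best_idx - avg_ckpt_num + 1)
    return summary[start:best_idx + 1]
```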
Yes, it says the following about model averaging:
averaging previous 5 checkpoints to the converged moment...
Average model over 5 checkpoints...
I don't see any problem with the embeddings model either, but the results are worse:
AMI
collar = 0: *** OVERALL *** DER: 17.30 JER: 23.47
collar = 0.25: *** OVERALL *** DER: 10.81 JER: 23.47
AISHELL4
collar = 0: *** OVERALL *** DER: 12.37 JER: 20.33
collar = 0.25: *** OVERALL *** DER: 6.39 JER: 20.34
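As a reminder of what these numbers mean: DER is the total duration of missed speech, false-alarm speech, and speaker confusion divided by the total scored speech time, and the collar excludes a small window around each reference boundary from scoring, which is why collar-0.25 numbers are lower. A toy illustration with made-up component durations:

```python
def der(missed, false_alarm, confusion, total_speech):
    """Diarization Error Rate (%) from component durations in seconds."""
    return 100.0 * (missed + false_alarm + confusion) / total_speech
```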
Hi, I just cloned our repo and re-ran inference for validation. We achieve the same numbers as stated in the README. My inference/configuration message is below:
stage2: model inference...
...
Namespace(configuration='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/config__2024_07_09--07_53_21.toml',
in_wav_scp='/PATH/DiariZen/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp',
out_dir='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/infer_debug_constrained_AHC_segmentation_step_0.1_min_cluster_size_30_AHC_thres_0.70_pyan_max_length_merged50/metric_Loss_prev/avg_ckpt5/test/AMI',
uem_file='/PATH/DiariZen/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem', avg_ckpt_num=5, val_metric='Loss', val_mode='prev',
val_metric_summary='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/val_metric_summary.lst',
segmentation_model='/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/best/pytorch_model.bin',
embedding_model='/PATH/pretrained/pyannote3/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin',
max_speakers=8, batch_size=32, max_length_merged='50', merge_closer='0.5', cluster_threshold=0.7, min_cluster_size=30, segmentation_step=0.1)
Average model over 5 checkpoints...
[{'epoch': 18, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0018/pytorch_model.bin'), 'Loss': 0.365, 'DER': 0.359},
{'epoch': 19, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0019/pytorch_model.bin'), 'Loss': 0.369, 'DER': 0.62},
{'epoch': 20, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0020/pytorch_model.bin'), 'Loss': 0.364, 'DER': 0.744},
{'epoch': 21, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0021/pytorch_model.bin'), 'Loss': 0.361, 'DER': 0.478},
{'epoch': 22, 'bin_path': PosixPath('/PATH/diar_ssl_icassp25/checkpoints_debug/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin'), 'Loss': 0.357, 'DER': 0.613}]
...
Everything works well for me, so currently I have no idea what happened on your side... Our AMI test data is a multi-channel version, and we always use the first channel for inference. You can find the data list here.
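Since only the first channel of the multi-channel recordings is used, reducing a multi-channel wav to its first channel looks roughly like this. A stdlib-only sketch for 16-bit PCM; this is a hypothetical helper, not part of the repo:

```python
import array
import wave

def first_channel(in_path, out_path):
    """Write channel 0 of a multi-channel 16-bit PCM wav to out_path."""
    with wave.open(in_path, "rb") as src:
        n_ch = src.getnchannels()
        sampwidth = src.getsampwidth()
        framerate = src.getframerate()
        samples = array.array("h", src.readframes(src.getnframes()))
    assert sampwidth == 2, "sketch handles 16-bit PCM only"
    mono = samples[::n_ch]  # frames are interleaved: every n_ch-th sample is ch 0
    with wave.open(out_path, "wb") as dst:
        dst.setnchannels(1)
        dst.setsampwidth(sampwidth)
        dst.setframerate(framerate)
        dst.writeframes(mono.tobytes())
```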
Thanks a lot for verifying the results! I will check them. I was using the Mix-Headset versions and this pyannote script to download the audio files.
You’re welcome. I guess this is the reason. Our model is pre-trained with only SDM data, and the results in our paper are all far-field SDM. Could you try running the inference with SDM data?
Sorry to bother you again!
This time I used the SDM version. I could see some differences in your configuration, such as min_cluster_size=30, which is absent in the run_stage.sh script. I tried the same settings based on your configuration, and here are the AMI collar-0 results: *** OVERALL *** DER: 17.69 JER: 22.34. I'm not sure if the audio files are the same, since what I downloaded has the suffix Array1-01.wav.
My inference/configuration message is below:
stage2: model inference...
Namespace(configuration='/PATH/diarization/checkpoints/wavlm_updated_conformer//config__2024_07_09--07_53_21.toml', in_wav_scp='/PATH/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp', out_dir='/PATH/diarization/checkpoints/wavlm_updated_conformer//infer_segmentation_step_0.1_min_cluster_size_30_AHC_thres_0.70_pyan_max_length_merged50/metric_Loss_prev/avg_ckpt5/test/AMI', uem_file='/PATH/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem', avg_ckpt_num=5, val_metric='Loss', val_mode='prev', val_metric_summary='/PATH/checkpoints/wavlm_updated_conformer//val_metric_summary.lst', segmentation_model=None, embedding_model='/PATH/wespeaker-embeddings/wespeaker-voxceleb-resnet34-LM/pyannote_pytorch_model.bin', max_speakers=8, batch_size=32, max_length_merged='50', merge_closer='0.5', cluster_threshold=0.7, min_cluster_size=30, segmentation_step=0.1)
averaging previous 5 checkpoints to the converged moment...
Average model over 5 checkpoints...
[{'epoch': 18, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0018/pytorch_model.bin'), 'Loss': 0.365, 'DER': 0.359}, {'epoch': 19, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0019/pytorch_model.bin'), 'Loss': 0.369, 'DER': 0.62}, {'epoch': 20, 'bin_path': PosixPath('/PATHs/wavlm_updated_conformer/checkpoints/epoch_0020/pytorch_model.bin'), 'Loss': 0.364, 'DER': 0.744}, {'epoch': 21, 'bin_path': PosixPath('/PATH/checkpoints/wavlm_updated_conformer/checkpoints/epoch_0021/pytorch_model.bin'), 'Loss': 0.361, 'DER': 0.478}, {'epoch': 22, 'bin_path': PosixPath('/PATH/wavlm_updated_conformer/checkpoints/epoch_0022/pytorch_model.bin'), 'Loss': 0.357, 'DER': 0.613}]
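On min_cluster_size: in agglomerative-clustering pipelines this knob typically forces clusters with fewer than N members to be merged into the nearest large cluster instead of surviving as spurious speakers. A toy 1-D sketch of that behavior (hypothetical, not the actual DiariZen/pyannote implementation):

```python
from collections import Counter

def reassign_small_clusters(labels, embeddings, min_size):
    """Merge clusters smaller than min_size into the nearest large one.

    labels: list[int] cluster assignments; embeddings: list[float]
    (toy 1-D embeddings). A sketch of what a min_cluster_size knob
    typically does in speaker clustering.
    """
    counts = Counter(labels)
    large = [c for c, n in counts.items() if n >= min_size]
    if not large:
        return labels
    centroid = {c: sum(e for l, e in zip(labels, embeddings) if l == c)
                   / counts[c] for c in large}
    return [l if l in large
            else min(large, key=lambda c: abs(centroid[c] - e))
            for l, e in zip(labels, embeddings)]
```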
Hi, no worries. I downloaded the SDM data you shared and re-ran inference again. The results on AMI with collar 0 are: *** OVERALL *** DER: 15.36 JER: 22.37.
So something must be wrong. Let's fix it offline. Could you send the scoring file of your AMI-SDM run to my email? (ihan@fit.vut.cz)
A fresh installation solved the issue.
@jyhan03 Thanks a lot for the amazing work.
I was trying to reproduce the numbers shown in the README for the public benchmarking data; I could only get a DER of 17 with wavLM_updated_conformer for collar 0 and a DER of 10.8 for collar 0.25 on AMI, for example. Could you please let me know the parameters used to get the exact results, or whether any wav files are missing from the test set? TIA