Audio-WestlakeU / FS-EEND

The official PyTorch implementation of "Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors". [ICASSP 2024]
MIT License
71 stars 4 forks

Use pre-trained model to run inference on a dataset #7

Closed DTDwind closed 6 months ago

DTDwind commented 8 months ago

Hello, I would like to evaluate the pre-trained model simu_avg_41_50epo.ckpt on AMI without fine-tuning. I have run into many obstacles along the way and am not sure I am on the right track, so I would appreciate your guidance. Here is what I have tried so far.

First, I prepared the AMI test set in Kaldi format, made a copy of spk_onl_tfm_enc_dec_nonautoreg_infer.yaml named ami_infer.yaml, and changed its train_data_dir and val_data_dir to point to my test set. Then I ran:

python train_dia.py --configs conf/ami_infer.yaml --gpus 0 --test_from_folder FS-EEND_simu_41_50epo_avg_model

This failed because no ckpt file could be found, so I changed

ckpts = [x for x in all_files if (".ckpt" in x) and ("epoch" in x) and int(x.split("=")[1].split("-")[0])>=configs["log"]["start_epoch"] and int(x.split("=")[1].split("-")[0])<=configs["log"]["end_epoch"]]

to

ckpts = [x for x in all_files if (".ckpt" in x)]

but then I hit the following error:

Traceback (most recent call last):
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/train_dia.py", line 217, in <module>
    train(configs, gpus=setup.gpus, checkpoint_resume=setup.checkpoint_resume, test_folder=setup.test_from_folder)
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/train_dia.py", line 185, in train
    for name, param in state_dict.items():
AttributeError: 'float' object has no attribute 'items'
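The epoch-range filter being relaxed above can be sketched as follows (a regex-based version of the same logic; `select_ckpts` and the sample filenames are illustrative, not the repo's actual code). It shows why the released averaged model, whose filename contains no `epoch=` field, never passes the original filter:

```python
import re

def select_ckpts(all_files, start_epoch, end_epoch):
    """Keep checkpoints whose filename encodes an epoch in [start_epoch, end_epoch].
    Assumes Lightning-style names such as 'epoch=41-step=12345.ckpt'."""
    selected = []
    for f in all_files:
        m = re.search(r"epoch=(\d+)", f)
        if f.endswith(".ckpt") and m and start_epoch <= int(m.group(1)) <= end_epoch:
            selected.append(f)
    return selected

files = ["epoch=41-step=100.ckpt",
         "epoch=60-step=200.ckpt",
         "FS-EEND_simu_41_50epo_avg_model.ckpt"]  # averaged model: no 'epoch=' field
print(select_ckpts(files, 41, 50))  # only the epoch=41 checkpoint matches
```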

I found that the program could not read the values in the ckpt correctly. After changing

state_dict = torch.load(test_folder + "/" + c, map_location="cpu")["state_dict"]

to

state_dict = torch.load(test_folder + "/" + c, map_location="cpu")

the values loaded correctly.
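The two load paths above differ only in whether the file holds a full Lightning checkpoint or a bare state_dict. A small hypothetical helper (not part of the repo) could accept both:

```python
def extract_state_dict(ckpt):
    """Return the parameter dict whether `ckpt` is a full Lightning checkpoint
    (a dict containing a 'state_dict' key) or already a bare state_dict,
    as the released averaged model appears to be."""
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        return ckpt["state_dict"]
    return ckpt

# Usage in train_dia.py would then be roughly:
# state_dict = extract_state_dict(torch.load(test_folder + "/" + c, map_location="cpu"))
```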

Next I hit TypeError: mel() takes 0 positional arguments but 3 were given, which I resolved by switching the call to keyword arguments: mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels). But then I ran into the following error:

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'TransformerEncoderFusionLayer' object has no attribute 'self_attn'. Did you mean: 'self_attn1'?

This seems to indicate that the model stored in the ckpt differs from what the code expects, yet other people's issues suggest they could run the code without problems. I don't understand why I am hitting so many issues. Have I missed a step somewhere? I would be grateful for guidance on the correct way to run this.
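For reference, the mel() error above comes from newer librosa versions (0.10+) making the parameters of filters.mel keyword-only. The stand-in below is NOT the real librosa implementation; it only illustrates why the positional call fails while the keyword call works on both old and new versions:

```python
# Stand-in for librosa.filters.mel (illustration only): the '*' makes every
# parameter keyword-only, which is what produces the reported TypeError.
def mel(*, sr, n_fft, n_mels=128):
    return (sr, n_fft, n_mels)  # placeholder for the real mel filterbank

try:
    mel(8000, 200, 23)          # positional call, as in older code
except TypeError as err:
    print(err)                  # mel() takes 0 positional arguments but 3 were given

fb = mel(sr=8000, n_fft=200, n_mels=23)  # keyword call works either way
```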

DiLiangWU commented 8 months ago

Hello. Steps 1 and 2, loading the ckpt, are fine; we explain this under "Performance" in the README. Step 3, the mel error, should be a librosa version issue; the librosa version we use is 0.9.2. As for step 4, it is not a problem with the parameter names saved in the ckpt: I checked, and the saved ckpt is fine and loads into the model correctly. You can print the parameters in train_dia.py to verify:

for name, param in state_dict.items():
    test_state[name] += param / len(ckpts)
    print(name, param)

[screenshot: printed parameter names and values]

Could you point out the exact location in the code where this error occurs, for example by debugging at the failing line?
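The snippet above accumulates `param / len(ckpts)` into `test_state`, i.e. it averages parameters across checkpoints. A minimal self-contained sketch of that pattern, with plain floats standing in for tensors (the dicts and values here are illustrative):

```python
def average_state_dicts(state_dicts):
    """Average parameters across checkpoints, mirroring the loop above.
    Assumes every state_dict shares the same parameter names."""
    avg = {name: 0.0 for name in state_dicts[0]}
    for sd in state_dicts:
        for name, param in sd.items():
            avg[name] += param / len(state_dicts)
    return avg

ckpts = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}]
print(average_state_dicts(ckpts))  # {'w': 2.0, 'b': 1.0}
```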

DTDwind commented 8 months ago

Thank you for the quick reply. It is a bit long, but I have attached my complete error message below. When I inspected the model parameters I did indeed find parameters named self_attn, and I am not sure whether that is what causes the problem.

The fact that you can run the program made me briefly suspect that I had downloaded the wrong model, but the file is named FS-EEND_ch_90_100epo_avg_model.ckpt, which looks correct, and part of the model's parameters match the screenshot you provided.

I cannot figure out where the problem lies and would appreciate your help. Thank you.

[screenshot: model parameter printout]

Traceback (most recent call last):
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/train_dia.py", line 217, in <module>
    train(configs, gpus=setup.gpus, checkpoint_resume=setup.checkpoint_resume, test_folder=setup.test_from_folder)
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/train_dia.py", line 191, in train
    trainer.test(spk_dia_main)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 757, in test
    return call._call_and_handle_interrupt(
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 806, in _test_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1137, in _run_stage
    return self._run_evaluate()
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_evaluate
    eval_loop_results = self._evaluation_loop.run()
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 399, in test_step
    return self.model.test_step(*args, **kwargs)
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/train/oln_tfm_enc_dec.py", line 185, in test_step
    preds, embs, attractors = self.model.test(feats, clip_lengths, self.max_spks + 2)
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/nnet/model/onl_tfm_enc_1dcnn_enc_linear_non_autoreg_pos_enc_l2norm.py", line 75, in test
    attractors = self.dec(emb, max_nspks)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/HDD/HDD2/DTDwind/FS-EEND/nnet/model/onl_tfm_enc_1dcnn_enc_linear_non_autoreg_pos_enc_l2norm.py", line 117, in forward
    attractors = self.attractor_decoder(attractors_init, t_mask)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 332, in forward
    batch_first = first_layer.self_attn.batch_first
  File "/home/DTDwind/.conda/envs/fs-eend/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'TransformerEncoderFusionLayer' object has no attribute 'self_attn'. Did you mean: 'self_attn1'?
Testing DataLoader 0:   0%|          | 0/41 [00:15<?, ?it/s]                                           
DTDwind commented 8 months ago

Hello, I found the cause. requirements.txt specifies torch>=1.13.0, but the code does not actually support newer torch versions; at least torch==2.x does not work for me, and that is exactly what a plain pip install -r requirements.txt installs. After reinstalling torch==1.13.0 the program runs correctly. Please update requirements.txt accordingly.
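This is consistent with the traceback: the failing line is inside torch's own transformer.py (`batch_first = first_layer.self_attn.batch_first`), an attribute access that the custom TransformerEncoderFusionLayer (which has self_attn1/self_attn2 instead) cannot satisfy. A hypothetical guard one could add near the top of train_dia.py to fail fast (the function name and version parsing are my own, not the repo's):

```python
def is_supported_torch(version_string):
    """Return True for torch 1.x (< 2.0), whose TransformerEncoder.forward
    does not take the code path accessing first_layer.self_attn directly.
    Handles local-version suffixes such as '2.1.0+cu118'."""
    major, minor = (int(x) for x in version_string.split("+")[0].split(".")[:2])
    return (major, minor) < (2, 0)

# e.g.: assert is_supported_torch(torch.__version__), "please pin torch==1.13.x"
print(is_supported_torch("1.13.0"), is_supported_torch("2.1.0+cu118"))
```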

DiLiangWU commented 8 months ago

OK, thank you for the feedback; requirements.txt has been updated.

DiLiangWU commented 7 months ago

@DTDwind Hello, may I ask whether you used the script at https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5c when preparing the AMI dataset? When I ran it, I found that the manifest and license URLs in local/ami_download.sh point to files that no longer exist. Did you encounter this problem, and if so, how did you solve it? Thanks!

DTDwind commented 7 months ago

I downloaded the audio directly from https://groups.inf.ed.ac.uk/ami/download/ , selecting all recordings and the Headset mix option. I remember running the script you mentioned without hitting many problems, though I am not sure the two give identical results. I did run https://github.com/BUTSpeechFIT/VBx to verify the correctness of my audio files: my AHC result matches the paper's DER of 21.43, while the VBx method shows a small deviation.

DiLiangWU commented 7 months ago

Great, thank you very much for the information.