
# FS-EEND

The official PyTorch implementation of "Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors".

This work is accepted by ICASSP 2024.


Paper :star_struck: | Issues :sweat_smile: | Lab :hear_no_evil: | Contact :kissing_heart:

# Introduction

This work proposes a frame-wise online/streaming end-to-end neural diarization (FS-EEND) method that operates in a frame-in-frame-out fashion. To detect a flexible number of speakers frame by frame and to extract/update their corresponding attractors, we leverage a causal speaker embedding encoder and an online non-autoregressive self-attention-based attractor decoder. A look-ahead mechanism is adopted to exploit a few future frames for effectively detecting new speakers in real time and adaptively updating speaker attractors.

The proposed FS-EEND architecture
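
To give some intuition for the look-ahead mechanism, below is a minimal, hypothetical PyTorch sketch of a self-attention mask that lets each frame attend to all past frames plus a small number of future frames. This is illustrative only and is not the repo's implementation; the embedding size, number of heads, and look-ahead width are made up, and the actual architecture follows the paper and the configs in conf/.

```python
# Illustrative only: a causal self-attention mask with a small look-ahead,
# allowing frame t to attend to frames s with s <= t + lookahead.
import torch

def lookahead_mask(num_frames: int, lookahead: int) -> torch.Tensor:
    """Boolean (num_frames, num_frames) mask; True marks key positions a
    query frame is NOT allowed to attend to (PyTorch's attn_mask convention)."""
    idx = torch.arange(num_frames)
    allowed = idx.unsqueeze(0) <= idx.unsqueeze(1) + lookahead
    return ~allowed

# dummy frame-level features: (batch, frames, feature_dim)
x = torch.randn(1, 100, 256)
attn = torch.nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
out, _ = attn(x, x, x, attn_mask=lookahead_mask(x.shape[1], lookahead=5))
print(out.shape)  # torch.Size([1, 100, 256])
```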

# Get started

1. Clone the FS-EEND code:

    git clone https://github.com/Audio-WestlakeU/FS-EEND.git

2. Prepare kaldi-style data by referring to here, and modify conf/xxx.yaml according to your own paths (an example directory layout is sketched after this list).

3. Start training on simulated data:

    python train_dia.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg.yaml --gpus YOUR_DEVICE_ID

4. Modify the pretrained model path in conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml.

5. Finetune on CALLHOME data:

    python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml --gpus YOUR_DEVICE_ID

6. Run inference (first modify your own path for saving predictions in test_step in train/oln_tfm_enc_decxxx.py):

    python train_diaxxx.py --configs conf/xxx_infer.yaml --gpus YOUR_DEVICE_ID --test_from_folder YOUR_CKPT_SAVE_DIR

7. Evaluate:

    # generate speech activity probabilities (diarization results)
    cd visualize
    python gen_h5_output.py

    # calculate DERs
    python metrics.py --configs conf/xxx_infer.yaml
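
For step 2, a typical Kaldi-style data directory looks like the sketch below. This layout is only the common Kaldi convention, not a specification of this repo; the exact files and label (RTTM) format required here follow the linked data-preparation recipe.

```
data/train/
├── wav.scp    # <recording-id> <path-to-wav or pipe command>
├── segments   # <utterance-id> <recording-id> <start-time> <end-time>
├── utt2spk    # <utterance-id> <speaker-id>
└── spk2utt    # <speaker-id> <utterance-id> ...
```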


# Performance
Please note that we use Switchboard Cellular (Part 1 and 2) and the 2005-2008 NIST Speaker Recognition Evaluation (SRE) corpora to generate the simulated data (4054 speakers in total).

| Dataset | DER(%) |ckpt|
| :--------: | :--: | :--: | 
| Simu1spk | 0.6 | [simu_avg_41_50epo.ckpt](https://drive.google.com/file/d/1JYr1zOxsHwQxIk9W4vwxzUfJFtaTQ02q/view?usp=sharing) |
| Simu2spk | 4.3 | same as above |
| Simu3spk | 9.8 | same as above |
| Simu4spk | 14.7 | same as above |
| CH2spk | 10.0 | [ch_avg_91_100epo.ckpt](https://drive.google.com/file/d/1i1Ow9IfPSwBRyRazY8-VX3z4ngDvSwx6/view?usp=sharing) |
| CH3spk | 15.3 | same as above |
| CH4spk | 21.8 | same as above |

The ckpts are the average of model parameters for the last 10 epochs.
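
If you want to reproduce such an averaged checkpoint from your own training run, a minimal sketch is shown below. It is a hypothetical example, not the released averaging script; the file names and the exp/ path are placeholders, and it assumes standard PyTorch Lightning checkpoints containing a "state_dict" key.

```python
# Hypothetical sketch: average the parameters of the last N checkpoints
# into a single plain state dict (dtypes are cast to float for simplicity).
import torch

def average_checkpoints(ckpt_paths):
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")["state_dict"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(ckpt_paths) for k, v in avg.items()}

# e.g. average epochs 41-50 and save as a plain state dict (example paths)
paths = [f"exp/epoch={e}.ckpt" for e in range(41, 51)]
torch.save(average_checkpoints(paths), "simu_avg_41_50epo.ckpt")
```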

If you want to check the performance of a released ckpt on CALLHOME:

    python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome_infer.yaml --gpus YOUR_DEVICE_ID --test_from_folder YOUR_CKPT_SAVE_DIR

Note that the checkpoint selection and loading code in train_dia_fintn_ch.py needs to be changed from

    ckpts = [x for x in all_files if (".ckpt" in x) and ("epoch" in x) and int(x.split("=")[1].split("-")[0])>=configs["log"]["start_epoch"] and int(x.split("=")[1].split("-")[0])<=configs["log"]["end_epoch"]]

    state_dict = torch.load(test_folder + "/" + c, map_location="cpu")["state_dict"]

to

    ckpts = [x for x in all_files if (".ckpt" in x)]

    state_dict = torch.load(test_folder + "/" + c, map_location="cpu")

so that every .ckpt file in the folder is picked up and loaded directly as a plain state dict (the released averaged checkpoints are saved without the Lightning "state_dict" wrapper).


# Reference code
- [EEND](https://github.com/hitachi-speech/EEND)
- [EEND-PyTorch](https://github.com/Xflick/EEND_PyTorch)

# Citation

If you want to cite this paper:

    @misc{liang2023framewise,
        title={Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors},
        author={Di Liang and Nian Shao and Xiaofei Li},
        year={2023},
        eprint={2309.13916},
        archivePrefix={arXiv},
        primaryClass={eess.AS}
    }