wenet-e2e / wespeaker

Research and Production Oriented Speaker Verification, Recognition and Diarization Toolkit
Apache License 2.0

[cli] support campplus_200k and eres2net_200k models of damo #281

Closed · cdliang11 closed this 4 months ago

cdliang11 commented 4 months ago

Usage:

wespeaker --eres2net --task embedding --audio_file ../1000003_0f90da0d.wav --output_file test.txt
wespeaker --campplus --task embedding --audio_file ../1000003_0f90da0d.wav --output_file test.txt
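
For reference, the same extraction can also be done from Python. A minimal sketch, assuming the new models are registered with the package's load_model API (the 'campplus' name string is an assumption based on the new --campplus CLI flag):

import wespeaker

# Load the CAMPPlus model added in this PR; the 'campplus' name is an
# assumption based on the new --campplus CLI flag.
model = wespeaker.load_model('campplus')

# Extract a speaker embedding from the same audio as the CLI example.
embedding = model.extract_embedding('../1000003_0f90da0d.wav')
print(embedding.shape)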
JiJiJiang commented 4 months ago

Good job! BTW, have you compared the embeddings extracted from the same audio by the wespeaker CLI and the 3d-speaker inference code?

cdliang11 commented 4 months ago

Good job! BTW, have you compared the embeddings extracted from the same audio by the wespeaker CLI and the 3d-speaker inference code?

The outputs of both are the same:

3d-speaker-campplus
python speakerlab/bin/infer_sv.py --model_id damo/speech_campplus_sv_zh-cn_16k-common --wavs ../1000003_0f90da0d.wav
[[-0.4085187   0.50557995  0.51636136  0.1142143   0.2404558  -1.1339788
  -1.0937457  -2.8552299   0.06838006  0.13499458  0.71838266  0.44210368
   2.2673156  -1.0265858  -0.50908566 -0.11793047  1.3293266  -1.6357919
   1.4727495  -0.07005835 -1.51254     2.8411684  -0.8805871   1.1905364
   0.3311209  -0.2748037  -0.5990227  -0.5974755   1.4530792   1.1028372
  -1.5573118  -0.18140262 -0.32122046  0.5014688  -0.20990026 -1.3896389
  -1.4280965   0.84823036  1.4982125   0.5026547   0.4770487   0.4092163
  -0.08150318  0.64153427 -1.3073081   1.5903411  -0.55855393  1.7382021
  -0.7740348   0.5523105  -0.50692123  1.0383366   0.9142423   0.4274579
   0.28643748  0.5127196  -0.13859004 -1.4316454   1.71023    -1.0759138
   0.850508   -2.1463156   0.9324371  -0.6005027  -0.32279783 -1.0592827
  -2.8073277  -0.72237927  0.9778304  -0.28658646  0.76317513 -1.2475644
   1.29144    -1.0815903   0.6496      1.2780393  -0.03912735  2.327661
  -0.07493195 -1.0741758   0.14549455 -0.89600277 -1.5814674  -0.4504004
  -0.61274517  0.42109856  0.14153782 -0.19801521  0.6075355  -0.16368878
   1.0157286  -0.6853902   1.2261169  -0.37317002 -1.7594169  -0.31674182
   0.01597187  1.1843703   1.4816221  -0.6894153   0.3778354  -0.6784406
   0.562997    0.5129908   1.262701   -0.31300706 -0.7393719   0.6164677
  -0.47587308  0.6412958   1.8476795   1.1121718  -0.43268675  0.22637391
  -1.0736793  -0.55764985  0.4124422   0.18827613  0.4426846  -0.7585005
   0.23592305 -0.5652432   0.02683749  0.7558515   0.72929215 -0.5218993
   0.95271045 -0.8250192  -0.67063344 -0.25632948 -0.19230154 -0.56856936
   1.0763328  -0.9233105  -1.4971652   0.67828107 -0.54156303  0.3079476
   0.50454146 -0.9912834  -0.88064206 -0.06452227  0.10343862 -1.0667944
   0.80994976 -1.7790337  -0.6505701  -0.5829258  -0.86960715 -0.1688903
   0.4326726   1.0910878   0.15102836 -0.84867907 -0.43227464  1.0685582
  -0.7399875   1.5167844  -0.48675457  0.10172446  0.20288429  0.1092979
  -0.4115819   0.6395146  -0.1783466   0.50276726  0.63910335 -1.8752003
   0.00609797 -1.4072868   0.7114489   0.596313    0.24544685  0.70587957
   1.4148142  -0.10190775  1.2026131   0.9365389   0.65656376  0.49392414
   0.8144008   0.11556987  1.1940337  -1.5439909   1.5463395  -0.44434774
  -0.70085263  0.60402256  1.0413061   0.46588537 -1.1118577   0.8619109 ]]
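For a quantitative check beyond eyeballing the printout, the two vectors can be compared with cosine similarity. A minimal sketch, assuming test.txt holds the wespeaker CLI output as a single whitespace-separated vector, and where campplus_3dspeaker.npy is a hypothetical np.save dump of the 3d-speaker vector above:

import numpy as np

def cosine(a, b):
    # Cosine similarity; a value near 1.0 means the two embeddings match.
    a, b = a.reshape(-1), b.reshape(-1)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# test.txt: wespeaker CLI output (single-vector text layout is an assumption).
# campplus_3dspeaker.npy: hypothetical dump of the infer_sv.py vector above.
wespeaker_emb = np.loadtxt('test.txt')
speakerlab_emb = np.load('campplus_3dspeaker.npy')
print(cosine(wespeaker_emb, speakerlab_emb))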
cdliang11 commented 4 months ago

I used the following code to convert the damo/eres2net model to the wespeaker format:

import torch

def convert_model(path, output_path):
    """Rename damo/eres2net state-dict keys to the wespeaker layout."""
    states = torch.load(path, map_location='cpu')
    # Blocks whose fused conv/bn lists need re-indexing.
    adapter_layers = [f"layer3.{i}" for i in range(6)] + \
                     [f"layer4.{i}" for i in range(3)]
    bn_params = ["weight", "bias", "running_mean",
                 "running_var", "num_batches_tracked"]
    for key in adapter_layers:
        # Slot 0 of the damo convs/bns lists becomes wespeaker's conv2_1/bn2_1.
        states[f"{key}.conv2_1.weight"] = states.pop(f"{key}.convs.0.weight")
        for p in bn_params:
            states[f"{key}.bn2_1.{p}"] = states.pop(f"{key}.bns.0.{p}")
        # The remaining slots shift down by one: 1 -> 0, then 2 -> 1.
        for src, dst in [(1, 0), (2, 1)]:
            states[f"{key}.convs.{dst}.weight"] = states.pop(f"{key}.convs.{src}.weight")
            for p in bn_params:
                states[f"{key}.bns.{dst}.{p}"] = states.pop(f"{key}.bns.{src}.{p}")

    torch.save(states, output_path)

if __name__ == "__main__":
    convert_model("/Users/user01/code/wespeaker-cli/pre_model/eres2net_commom/avg_model.pt",
                  "/Users/user01/code/wespeaker-cli/pre_model/eres2net_commom/avg_model_convert.pt")