wenet-e2e / wespeaker

Research and Production Oriented Speaker Verification, Recognition and Diarization Toolkit
Apache License 2.0

Exporting ONNX Model with Final Class Probabilities Instead of Embeddings #340

Closed: didi1233 closed this issue 2 months ago

didi1233 commented 3 months ago

Dear WeSpeaker Team,

I am trying to use WeSpeaker for a classification task. I have trained a three-class model using ResNet and a Linear classifier. However, I would like the exported ONNX model to output the final class probabilities instead of the embeddings. I attempted to use the following code, but the inference results are incorrect. Could you please help me identify the problem? Thank you.

from __future__ import print_function

import argparse
import numpy as np
import torch
import torch.nn as nn
import yaml

from speaker_model import get_speaker_model
from checkpoint import load_checkpoint
from resnet import ResNet18

class Linear(nn.Module):
    """
    The linear transform for simple softmax loss
    """

    def __init__(self, emb_dim=256, class_num=3):
        super(Linear, self).__init__()

        self.trans = nn.Sequential(nn.BatchNorm1d(emb_dim),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(emb_dim, class_num))

    def forward(self, input, label=None):
        out = self.trans(input)
        return out

class CombinedModel(nn.Module):
    def __init__(self, resnet_model, classifier):
        super(CombinedModel, self).__init__()
        self.resnet_model = resnet_model
        self.classifier = classifier

    def forward(self, x):
        _, embeddings = self.resnet_model(x)
        output = self.classifier(embeddings)
        return output

def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_model', required=True, help='output file')
    parser.add_argument('--mean_vec',
                        required=False,
                        default=None,
                        help='mean vector')
    args = parser.parse_args()
    return args

def main():
    args = get_args()

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    resnet_model = ResNet18(feat_dim=80, embed_dim=256, two_emb_layer=False)
    classifier = Linear(emb_dim=256, class_num=3)
    model = CombinedModel(resnet_model, classifier)
    load_checkpoint(model, args.checkpoint)
    model.eval()

    feat_dim = configs['model_args'].get('feat_dim', 80)
    if 'feature_args' in configs:  # deprecated IO
        num_frms = configs['feature_args'].get('num_frms', 200)
    else:  # UIO
        num_frms = configs['dataset_args'].get('num_frms', 200)

    dummy_input = torch.ones(1, num_frms, feat_dim)
    torch.onnx.export(model,
                      dummy_input,
                      args.output_model,
                      do_constant_folding=True,
                      verbose=False,
                      opset_version=14,
                      input_names=['feats'],
                      output_names=['embs'],
                      dynamic_axes={
                          'feats': {
                              0: 'B',
                              1: 'T'
                          },
                          'embs': {
                              0: 'B'
                          }
                      })

if __name__ == '__main__':
    main()
JiJiJiang commented 3 months ago

Thank you for your question. We only save the model part before the speaker embedding layer in WeSpeaker, while the classifier (from embedding to speaker label) is not saved during training. For your case, you should modify the code in save_checkpoint and save the classifier; otherwise the classifier would have random parameters and the final results would be random too.
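
For reference, a minimal sketch of what saving the classifier alongside the backbone could look like (a hypothetical helper, not the toolkit's actual save_checkpoint; as the next comment clarifies, the full model is in fact already saved):

import torch
import torch.nn as nn

def save_checkpoint_with_classifier(model: nn.Module, classifier: nn.Module, path: str):
    # Hypothetical helper: store the embedding backbone and the classification
    # head in one file so the classifier parameters survive training.
    torch.save({'model': model.state_dict(),
                'classifier': classifier.state_dict()}, path)

def load_checkpoint_with_classifier(model: nn.Module, classifier: nn.Module, path: str):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    classifier.load_state_dict(ckpt['classifier'])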

JiJiJiang commented 2 months ago

@didi1233 Have you solved this problem? My answer above was not correct. We did save the whole speaker model, including the classifier you need, but during extraction we did not load the classifier part.

For your case, if you train the model with the naive softmax loss, I think it should work well using our training pipeline. Did you find why it did not work?
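
As a quick sanity check (a sketch, assuming the checkpoint is a plain state_dict, as the load_checkpoint calls in this thread suggest), the saved keys can be listed to confirm the classifier/projection weights really are in the file:

import torch

# Inspect the checkpoint and print any parameters belonging to the projection
# (classifier) head. 'model_10.pt' is the checkpoint name used further below.
state_dict = torch.load('model_10.pt', map_location='cpu')
for name, tensor in state_dict.items():
    if name.startswith('projection.'):
        print(name, tuple(tensor.shape))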

didi1233 commented 2 months ago

> @didi1233 Have you solved this problem? My answer above was not correct. We did save the whole speaker model, including the classifier you need, but during extraction we did not load the classifier part.
>
> For your case, if you train the model with the naive softmax loss, I think it should work well using our training pipeline. Did you find why it did not work?

Hello, I have resolved this issue. WeSpeaker does indeed save the classifier part. My problem came from wrapping the model in the CombinedModel class, which changed the parameter (node) names, so the saved classification-layer weights failed to load. The correct example code is as follows:

import copy
import os

import fire
import kaldiio
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn as nn
import torch.nn.functional as F
#from wespeaker.dataset.dataset import Dataset
from speaker_model import get_speaker_model
from checkpoint import load_checkpoint
#from wespeaker.utils.utils import parse_config_or_kwargs, validate_path
from utils import parse_config_or_kwargs
from projections import get_projection

def compute_fbank(wav_path,
                  num_mel_bins=40,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank, simlilar to the one in wespeaker.dataset.processor,
        While integrating the wave reading and CMN.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    waveform = waveform * (1 << 15)
    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      sample_frequency=sample_rate,
                      window_type='hamming',
                      use_energy=False)
    # CMN, without CVN
    mat = mat - torch.mean(mat, dim=0)
    return mat

def extract(config='config.yaml', **kwargs):
    configs = parse_config_or_kwargs(config, **kwargs)
    batch_size = 1
    num_workers = 1

    torch.backends.cudnn.benchmark = False
    model_path = 'model_10.pt'
    model = get_speaker_model(configs['model'])(**configs['model_args'])
    configs['projection_args']['embed_dim'] = configs['model_args']['embed_dim']
    configs['projection_args']['num_class'] = 2 #your class num!!!
    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
    projection = get_projection(configs['projection_args'])
    model.add_module("projection", projection)

    device = torch.device("cuda")
    model.to(device).eval() 
    load_checkpoint(model, model_path)
    print(model)

    with torch.no_grad():
        feats = compute_fbank('your.wav')
        feats = feats.unsqueeze(0)  # add batch dimension 
        features = feats.float().to(device)
        #print(features.shape)
        outputs = model(features)
        embeds = outputs[-1] if isinstance(outputs, tuple) else outputs  # (B, F)
        outputs = projection(embeds).cpu().detach()
        print(outputs)

if __name__ == '__main__':
    fire.Fire(extract)

Additionally, it is necessary to set label = None in https://github.com/wenet-e2e/wespeaker/blob/master/wespeaker/models/projections.py#L483.
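
With the plain softmax-style projection, the forward pass above returns raw logits, so if actual class probabilities are wanted (as in the original question), a softmax can be applied on top. A minimal, self-contained sketch with dummy logits standing in for the projection output:

import torch
import torch.nn.functional as F

# Dummy (B, num_class) logits standing in for the projection output above.
logits = torch.tensor([[2.0, -1.0]])
probs = F.softmax(logits, dim=-1)   # class probabilities, sum to 1 per row
pred = torch.argmax(probs, dim=-1)  # predicted class index
print(probs, pred)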

Attached is the binary classification model for noise and speech, along with the config.yaml file. I hope this will be helpful to those who need it.

Best regards!

example.zip
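
To close the loop on the original question (exporting an ONNX model that emits class probabilities), one possible approach is to load the checkpoint into the unmodified module structure first, so the parameter names match, and only then wrap it for export. The sketch below is my own, untested variant: ClassifierExport, export(), the 'probs' output name, and the softmax on top are additions of mine; it reuses get_speaker_model, get_projection, load_checkpoint and parse_config_or_kwargs exactly as in the script above and assumes the projection's forward accepts a single argument after the label = None change mentioned there.

import torch
import torch.nn as nn

from speaker_model import get_speaker_model
from checkpoint import load_checkpoint
from projections import get_projection
from utils import parse_config_or_kwargs

class ClassifierExport(nn.Module):
    """Thin wrapper used only for export: backbone -> projection -> softmax."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, feats):
        outputs = self.model(feats)
        embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
        logits = self.model.projection(embeds)
        return torch.softmax(logits, dim=-1)

def export(config='config.yaml', checkpoint='model_10.pt', output='model.onnx'):
    configs = parse_config_or_kwargs(config)
    model = get_speaker_model(configs['model'])(**configs['model_args'])
    configs['projection_args']['embed_dim'] = configs['model_args']['embed_dim']
    configs['projection_args']['num_class'] = 2  # your class num, as above
    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
    model.add_module('projection', get_projection(configs['projection_args']))
    load_checkpoint(model, checkpoint)  # names match, so the classifier loads
    model.eval()

    wrapper = ClassifierExport(model)
    dummy_input = torch.ones(1, 200, configs['model_args'].get('feat_dim', 80))
    torch.onnx.export(wrapper, dummy_input, output,
                      do_constant_folding=True, opset_version=14,
                      input_names=['feats'], output_names=['probs'],
                      dynamic_axes={'feats': {0: 'B', 1: 'T'}, 'probs': {0: 'B'}})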

JiJiJiang commented 2 months ago

I see. Thanks for your answer.