modelscope / FunASR

A Fundamental End-to-End Speech Recognition Toolkit and Open Source SOTA Pretrained Models, Supporting Speech Recognition, Voice Activity Detection, Text Post-processing etc.
https://www.funasr.com

Paraformer hotword model: fine-tuning and ONNX export #851

Closed ROAD2018 closed 1 year ago

ROAD2018 commented 1 year ago

For local modifications to the FunASR source to take effect, FunASR needs to be installed from source:

git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
# For the users in China, you could install with the command:
# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
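
A quick way to confirm that the editable install is actually being picked up (this check is not part of the original instructions) is to print where the package is imported from; the path should point into the cloned FunASR directory:

import funasr
print(funasr.__file__)   # should resolve to a path inside the cloned FunASR/ checkout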

1. Hotword model training

First download the damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch pretrained package to a local directory, load the pretrained model and its configuration files via modelscope_args, and then fine-tune the model with the following code:

import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset

def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()

if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args
    params = modelscope_args(model="./speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", data_path="./data")
    params.output_dir = "./checkpoint"       # 模型保存路径
    params.data_path = "./dataset/aishell1"     # 数据路径,可以为modelscope中已上传数据,也可以是本地数据
    params.dataset_type = "small"     # 小数据量设置small,若数据量大于1000小时,请使用large
    params.batch_bins = 2000      # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
    params.max_epoch = 50                 # 最大训练轮数
    params.lr = 0.00005                     # 设置学习率

    modelscope_finetune(params)

The training data is prepared as follows: a text file with transcripts and a wav.scp file with audio paths; Chinese and English may be mixed:

cat ./example_data/text
BAC009S0002W0122 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购
BAC009S0002W0123 也 成 为 地 方 政 府 的 眼 中 钉
english_example_1 hello world
english_example_2 go swim 去 游 泳

cat ./example_data/wav.scp
BAC009S0002W0122 /mnt/data/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /mnt/data/wav/train/S0002/BAC009S0002W0123.wav
english_example_1 /mnt/data/wav/train/S0002/english_example_1.wav
english_example_2 /mnt/data/wav/train/S0002/english_example_2.wav
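
As a small sanity check before training (this snippet is not part of FunASR), the two files can be compared to make sure every utterance id in text has a matching entry in wav.scp:

# Sketch of a sanity check: utterance ids in text and wav.scp should match.
def read_kaldi_file(path):
    entries = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if parts:
                entries[parts[0]] = parts[1] if len(parts) > 1 else ""
    return entries

text = read_kaldi_file("./example_data/text")
wav_scp = read_kaldi_file("./example_data/wav.scp")
mismatched = set(text) ^ set(wav_scp)
print("mismatched utterance ids:", mismatched if mismatched else "none")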

2. Exporting the hotword model to ONNX

The current version lacks the code needed for hotword ONNX export, so code has to be added in the places described below.

Note: at the moment, hotwords are sampled dynamically from the training text during training, and the training code uses the ContextualParaformer class in funasr/models/e2e_asr_paraformer.py, whereas ONNX export actually uses the NeatContextualParaformer class in funasr/models/e2e_asr_contextual_paraformer.py. A few additional scripts are therefore needed to export the ONNX model correctly.

The idea behind the hotword export is to split the system into two models: the Paraformer main model and the hotword-embedding extractor. This has two benefits: 1) hotword embeddings do not have to be re-extracted on every call; all hotwords can be embedded once at initialization, and any new hotwords supplied at inference time are embedded incrementally, which reduces computation while keeping hotword handling flexible; 2) the hotword-embedding extractor is an LSTM whose batch size can only be 1, which cannot match the batch size of the Paraformer main model, so the extractor has to be separated into its own model. A usage sketch of the two exported models follows.
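Below is a minimal sketch (not code from this patch) of how the two exported ONNX models are meant to work together with onnxruntime: the bias-encoder model is run once per hotword at startup to build the hotword embeddings, and the Paraformer main model then consumes the stacked embeddings for every utterance. The input/output names and model file names follow the export and inference code shown later in this post; the token ids and features are placeholders.

import numpy as np
import onnxruntime as ort

# Bias-encoder model: maps one hotword (token ids) to a single embedding vector.
bias_sess = ort.InferenceSession("export_dir/model_contextual_bias_encoder.onnx")
# Paraformer main model: consumes features plus the stacked hotword embeddings.
asr_sess = ort.InferenceSession("export_dir/model.onnx")

hotword_token_ids = [np.array([10, 11, 12], dtype=np.int32),   # placeholder token ids
                     np.array([20, 21], dtype=np.int32)]
# Extract all hotword embeddings once at initialization (the LSTM only takes batch size 1).
hotword_embeds = np.stack(
    [bias_sess.run(None, {"hotword": ids})[0] for ids in hotword_token_ids], axis=0)

feats = np.random.randn(1, 100, 560).astype(np.float32)        # placeholder fbank+LFR features
feats_len = np.array([100], dtype=np.int32)
clas_scale = np.array([1.0], dtype=np.float32)
logits, token_num = asr_sess.run(
    None, {"speech": feats, "speech_lengths": feats_len,
           "hotword_embeds": hotword_embeds, "clas_scale": clas_scale})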

The changes are mainly to files under the funasr/export folder, specifically:

1) Modify export_model.py: since the hotword version exports two ONNX models, the get_model function is changed to return a list, and the calling code handles it as follows:

      self.export_config["model_name"] = "model"
      models = get_model(
          model,
          self.export_config,
      )
      print('models:', models)
      print('models type:', type(models))
      if len(models) > 1:
          asr_model, bias_model = models[0], models[1]
          asr_model.eval()
          # self._export_onnx(model, verbose, export_dir)
          if self.onnx:
              self._export_onnx(asr_model, verbose, export_dir)
          else:
              self._export_torchscripts(asr_model, verbose, export_dir)
          print("output {} dir: {}".format('asr_model', export_dir))

          bias_model.eval()
          # self._export_onnx(model, verbose, export_dir)
          if self.onnx:
              self._export_onnx(bias_model, verbose, export_dir)
          else:
              self._export_torchscripts(bias_model, verbose, export_dir)

          print("output {} dir: {}".format('bias_model', export_dir))

      else:
          models = models[0]
          models.eval()
          # self._export_onnx(model, verbose, export_dir)
          if self.onnx:
              self._export_onnx(models, verbose, export_dir)
          else:
              self._export_torchscripts(models, verbose, export_dir)

          print("output dir: {}".format(export_dir))

2) Modify the get_model function in funasr/export/models/__init__.py. Together with this, funasr/export/models/e2e_asr_paraformer.py gains the NeatContextualParaformer and ContextualBiasEncoder classes, which export the Paraformer main model and the LSTM hotword model.

from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import NeatContextualParaformer as ContextualParaformer_export
from funasr.export.models.e2e_asr_paraformer import ContextualBiasEncoder as ContextualBiasEncoder_export

from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export

def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
        return [BiCifParaformer_export(model, **export_config)]
    elif isinstance(model, NeatContextualParaformer):
        print('export NeatContextualParaformer ...')
        return [ContextualParaformer_export(model, **export_config), ContextualBiasEncoder_export(model, **export_config)]
    elif isinstance(model, Paraformer):
        print('export Paraformer ...')
        return [Paraformer_export(model, **export_config)]
    elif isinstance(model, E2EVadModel):
        print('export E2EVadModel ...')
        return [E2EVadModel_export(model, **export_config)]
    elif isinstance(model, PunctuationModel):
        print('export PunctuationModel ...')
        if isinstance(model.punc_model, TargetDelayTransformer):
            print('export TargetDelayTransformer ...')
            return [CT_Transformer_export(model.punc_model, **export_config)]
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            print('export VadRealtimeTransformer ...')
            return [CT_Transformer_VadRealtime_export(model.punc_model, **export_config)]

3) Modify funasr/export/models/e2e_asr_paraformer.py by adding the NeatContextualParaformer and ContextualBiasEncoder classes, which export the Paraformer main model and the LSTM hotword model.

class NeatContextualParaformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            logging.warning("Unsupported encoder type to export.")

        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)

        # self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        # if hasattr(model, 'bias_encoder'):   # if isinstance(model.bias_encoder, torch.nn.LSTM):
            # logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
            # self.bias_encoder = model.bias_encoder
        # else:
            # logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualParaformerDecoder_export(model.decoder, onnx=onnx)
        else:
            logging.warning("Unsupported decoder type to export.")

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            hotword_embeds: torch.Tensor,
            clas_scale: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        # -1. bias encoder
        contextual_info = hotword_embeds.unsqueeze(0).repeat(pre_acoustic_embeds.shape[0], 1, 1)
        # print('contextual_info:', contextual_info.shape)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, contextual_info=contextual_info, clas_scale=clas_scale)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        hotword_embeds = torch.randn(30, self.context_dim)
        clas_scale = torch.tensor([1.0], dtype=torch.float32)
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt", hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        hotword = np.loadtxt(hot_file)
        hotword_embeds = torch.from_numpy(hotword.astype(np.float32))
        clas_scale = np.array([hotword[:,-1], ], dtype=np.float32)
        clas_scale = torch.from_numpy(clas_scale.astype(np.float32))
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_input_names(self):
        return ['speech', 'speech_lengths', 'hotword_embeds', 'clas_scale']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'hotword_embeds': {
                0: 'hotword_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }

class ContextualBiasEncoder(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        if hasattr(model, 'bias_encoder'):   # if isinstance(model.bias_encoder, torch.nn.LSTM):
            logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
            self.bias_encoder = model.bias_encoder
        else:
            logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name + '_contextual_bias_encoder'

    def forward(
            self,
            hotword: torch.Tensor,
            # h0: torch.Tensor,
            # c0: torch.Tensor,
    ):
        # -1. bias encoder
        hotword = hotword.unsqueeze(0)
        hw_embed = self.bias_embed(hotword)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        print('hw_embed:', hw_embed.shape)
        hotword_embed = hw_embed[0,-1,:].squeeze(0)
        print('hotword_embed:', hotword_embed.shape)

        return hotword_embed

    def get_dummy_inputs(self):
        hotword = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).type(torch.int32)
        #hotword = torch.Tensor([1]).type(torch.int32)
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        #return (hotword, h0, c0)
        return hotword

    def get_dummy_inputs_txt(self, hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        hotword = np.loadtxt(hot_file)
        hotword = torch.from_numpy(hotword.astype(np.int32))
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # return (hotword, h0, c0)
        return hotword

    def get_input_names(self):
        #return ['hotword', 'h0', 'c0']
        return ['hotword']

    def get_output_names(self):
        #return ['hotword_embed', 'hn', 'cn']
        return ['hotword_embed']

    def get_dynamic_axes(self):
        return {
            'hotword': {
                0: 'hotword_length',
            },
        }
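
For reference, here is a minimal sketch of how the helper methods of these wrapper classes (get_dummy_inputs, get_input_names, get_output_names, get_dynamic_axes) are typically consumed when producing the ONNX file; the actual _export_onnx in funasr/export/export_model.py may differ in details such as the opset version.

import torch

def export_to_onnx(export_model, onnx_path):
    # Trace the wrapper module with its dummy inputs and write the ONNX graph.
    export_model.eval()
    dummy_inputs = export_model.get_dummy_inputs()
    torch.onnx.export(
        export_model,
        dummy_inputs,
        onnx_path,
        input_names=export_model.get_input_names(),
        output_names=export_model.get_output_names(),
        dynamic_axes=export_model.get_dynamic_axes(),
        opset_version=13,
    )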

4) Add a new script, funasr/export/models/modules/contextual_decoder_layer.py. Its two classes, ContextualDecoderLayer and ContextualBiasDecoder, implement the attention computation that the decoder needs for hotwords.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn
from typing import Tuple

class ContextualDecoderLayer(nn.Module):

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):

        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
            x_self_attn = x

        residual = x
        if self.src_attn is not None:
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)
            x_src_attn = x

        x = residual + x

        return x, tgt_mask, x_self_attn, x_src_attn

class ContextualBiasDecoder(nn.Module):

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.src_attn = model.src_attn
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):

        x = tgt
        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache

5) Add a new script, funasr/export/models/decoder/contextual_decoder.py, which exports the decoder.

import os

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask

from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export

from funasr.models.decoder.contextual_decoder import ContextualDecoderLayer as ContextualDecoderLayer
from funasr.models.decoder.contextual_decoder import ContextualBiasDecoder as ContextualBiasDecoder
from funasr.export.models.modules.contextual_decoder_layer import ContextualDecoderLayer as ContextualDecoderLayer_export
from funasr.export.models.modules.contextual_decoder_layer import ContextualBiasDecoder as ContextualBiasDecoder_export

class ContextualParaformerSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        if self.model.last_decoder is not None:
            if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
                print('export model.last_decoder, feed_forward ...')
                self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
            if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
                print('export model.last_decoder, self_attn ...')
                self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
            if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.last_decoder, src_attn ...')
                self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
            self.model.last_decoder = ContextualDecoderLayer_export(self.model.last_decoder)

        if self.model.bias_decoder is not None:
            if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.bias_decoder, src_attn ...')
                self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
            self.model.bias_decoder = ContextualBiasDecoder_export(self.model.bias_decoder)

        self.model.bias_output = self.model.bias_output if hasattr(model, 'bias_output') else None

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        clas_scale: float = 1.0
    ):

        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _  = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask  = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        #print('tgt:',tgt.shape, 'tgt_mask:',tgt_mask.shape, 'memory:',memory.shape, 'memory_mask:',memory_mask.shape)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        #print('\n>>> export last_decoder ....')
        _, _, x_self_attn, x_src_attn = self.model.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        #print('x_self_attn:',x_self_attn.shape, 'x_src_attn:',x_src_attn.shape)

        # contextual paraformer related
        #contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_length = contextual_info.shape[1].repeat(hs_pad.shape[0])
        contextual_mask = self.make_pad_mask(contextual_length)
        _, contextual_mask = self.prepare_mask(contextual_mask)
        #contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.model.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.model.bias_output is not None:
            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
            x = self.model.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + x

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        if hasattr(self.model, 'last_decoder'):
            cache.append(torch.zeros((1, self.model.last_decoder.size, self.model.last_decoder.self_attn.kernel_size)))
        if hasattr(self.model, 'bias_decoder'):
            cache.append(torch.zeros((1, self.model.bias_decoder.size, self.model.bias_decoder.src_attn.kernel_size)))
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']   
        return return_list

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']
        return return_list

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        ret.update({
            'contextual_info': {
                0: 'contextual_info_batch',
                1: 'contextual_info_length'
            }
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2) + len(self.model.last_decoder),
            "odim": self.model.decoders[0].size
        }

6) Add funasr/runtime/python/onnxruntime/funasr_onnx/contextual_paraformer_bin.py for hotword inference.

# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple

import copy
import librosa
import numpy as np
import re

from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx

logging = get_logger()

class Contextual_Paraformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):

        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError("model_dir must be a model name from modelscope or a local path downloaded from modelscope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            raise FileNotFoundError("model_file {} does not exist, please export the onnx model first".format(model_file))

        context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder.onnx')
        if quantize:
            context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder_quant.onnx')
        if not os.path.exists(context_model_file):
            raise FileNotFoundError("context_model_file {} does not exist, please export the onnx model first".format(context_model_file))

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

        # contextual
        hotword_file = os.path.join(model_dir, 'hotword.txt')
        with open(hotword_file, 'r') as f:
            hotword_list = f.readlines()
        #print('context_model_file:', context_model_file)
        #print('hotword_list:', hotword_list)
        self.context_ort_infer = OrtInferSession(context_model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.hotword_embeds = self.extract_hotword_embed(hotword_list)
        print('self.hotword_embeds:', len(self.hotword_embeds), self.hotword_embeds[0].shape)

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], 
                       hotwords: List[str], clas_scale: np.ndarray, **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        hotword_embeds = self.load_hotword(hotwords)
        if not isinstance(clas_scale, np.ndarray):
            clas_scale = np.array([clas_scale], dtype=np.float32)
            print('clas_scale:', type(clas_scale))
        print('hotword_embeds:', hotword_embeds.shape)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):

            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            print('feats:', feats.shape)
            print('feats_len:', feats_len)
            print('hotword_embeds:', hotword_embeds.shape)
            print('clas_scale:', clas_scale)
            try:
                outputs = self.infer(feats, feats_len, hotword_embeds, clas_scale)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                print('am_scores:', am_scores.shape)
                if len(outputs) == 4:
                    # for BiCifParaformer Inference
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                #logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
                        text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
                        # logging.warning(timestamp)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
                        asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
        return asr_res

    def plot_wave_timestamp(self, wav, text_timestamp, dest):
        # TODO: Plot the wav and timestamp results with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set it to a font that your system supports
        import matplotlib.pyplot as plt
        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4)
        # plot lines and text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        # plt.legend()
        plotname = "{}/timestamp.png".format(dest)
        plt.savefig(plotname, bbox_inches='tight')

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def load_hotword(self, hotwords: List[str]) -> np.ndarray:
        hotwords_embeds_list = self.hotword_embeds
        new_hotwords = []
        if len(hotwords) > 0:
            new_hotwords = self.extract_hotword_embed(hotwords)
            hotwords_embeds_list.extend(new_hotwords)
        hotwords_embeds = np.stack(hotwords_embeds_list, axis=0)
        return hotwords_embeds

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def extract_hotword_embed(self, hotword_list):
        tokenids_list, hw_embed_list = [], []
        for hotword in hotword_list:
            hotword = hotword.strip()
            tokenids = self.converter.tokens2ids(hotword)
            tokenids_list.append(tokenids)
            tokenids = np.array(tokenids, dtype=np.int32)
            hw_embed = self.context_ort_infer([tokenids])[0]
            hw_embed_list.append(hw_embed)
        return hw_embed_list

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray,
              hotword_embed: np.ndarray, 
              clas_scale: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, hotword_embed, clas_scale])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
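
A hypothetical usage example of the class above (the import path, model directory and hotwords are placeholders; clas_scale follows the __call__ signature defined above):

import numpy as np
from funasr_onnx.contextual_paraformer_bin import Contextual_Paraformer

# model_dir must contain model.onnx, model_contextual_bias_encoder.onnx,
# config.yaml, am.mvn and hotword.txt, as expected by __init__ above.
model = Contextual_Paraformer("./export/contextual_paraformer", batch_size=1)
result = model("./example/asr_example.wav",
               hotwords=["魔搭", "达摩院"],
               clas_scale=np.array([1.0], dtype=np.float32))
print(result)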
hzfei commented 1 year ago

Is a surprise on the way?!

LauraGPT commented 1 year ago

要想使得对FunASR源码的修改在本地生效,则需要通过源码方式安装FunASR,即:

git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
# For the users in China, you could install with the command:
# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple

1. 热词模型训练

首先将 damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch 预训练包下载到本地,通过 modelscope_args 加载预训练模型和相关配置文件,然后使用如下代码进行模型微调:

import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset

def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()

if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args
    params = modelscope_args(model="./speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", data_path="./data")
    params.output_dir = "./checkpoint"       # 模型保存路径
    params.data_path = "./dataset/aishell1"     # 数据路径,可以为modelscope中已上传数据,也可以是本地数据
    params.dataset_type = "small"     # 小数据量设置small,若数据量大于1000小时,请使用large
    params.batch_bins = 2000      # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
    params.max_epoch = 50                 # 最大训练轮数
    params.lr = 0.00005                     # 设置学习率

    modelscope_finetune(params)

训练数据准备形式如下,包含:text文本数据和wav.scp音频数据,可以是中英文混合:

cat ./example_data/text
BAC009S0002W0122 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购
BAC009S0002W0123 也 成 为 地 方 政 府 的 眼 中 钉
english_example_1 hello world
english_example_2 go swim 去 游 泳

cat ./example_data/wav.scp
BAC009S0002W0122 /mnt/data/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /mnt/data/wav/train/S0002/BAC009S0002W0123.wav
english_example_1 /mnt/data/wav/train/S0002/english_example_1.wav
english_example_2 /mnt/data/wav/train/S0002/english_example_2.wav

2. 热词模型导出onnx模型

目前版本缺少热词onnx导出相关代码,为此需要在相应位置处增加代码,具体如下: 注意:目前热词是在训练时候从训练数据中动态选择文本作为热词进行训练,训练代码用的是funasr/models/ e2e_asr_paraformer.py脚本里面的 ContextualParaformer 类;而模型onnx导出的时候实际用的是funasr/models/ e2e_asr_contextual_paraformer.py脚本里面的 NeatContextualParaformer类。为了正确导出onnx模型需要增加一些脚本。 热词版本导出的思路:将Paraformer主体模型和热词向量提取模型分隔成两个模型,这样做有两个好处:1)热词不用每次都重新提取热词向量,可以在初始化时候一次性将所有热词都提取出来,推理的时候如果有新的热词再额外增加,可以减少计算量,同时保持热词提取的灵活性;2)由于热词向量提取使用的是LSTM,batch_size只能是1,无法与Paraformer主体模型的batch_size匹配,所以需要单独把热词向量提取模块分离出来。

修改主要是:funasr/export文件夹下的文件,具体的有:

1)export_model.py 脚本修改,由于热词版本的模型会导出两个onnx模型,所以会修改 get_model 函数,使其返回的是一个列表,即:

    self.export_config["model_name"] = "model"
    models = get_model(
        model,
        self.export_config,
    )
    print('models:', models)
    print('models type:', type(models))
    if len(models) > 1:
        asr_model, bias_model = models[0], models[1]
        asr_model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(asr_model, verbose, export_dir)
        else:
            self._export_torchscripts(asr_model, verbose, export_dir)
        print("output {} dir: {}".format('asr_model', export_dir))

        bias_model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(bias_model, verbose, export_dir)
        else:
            self._export_torchscripts(bias_model, verbose, export_dir)

        print("output {} dir: {}".format('bias_model', export_dir))

    else:
        models = models[0]
        models.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(models, verbose, export_dir)
        else:
            self._export_torchscripts(models, verbose, export_dir)

        print("output dir: {}".format(export_dir))

2)funasr/export/models/init.py 脚本修改里面 get_model 函数: 其中,修改了 funasr/export/models/e2e_asr_paraformer.py 函数,脚本中增加了 NeatContextualParaformer 和 ContextualBiasEncoder 两个类函数来导出Paraformer主体模型和LSTM热词模型。

from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import NeatContextualParaformer as ContextualParaformer_export
from funasr.export.models.e2e_asr_paraformer import ContextualBiasEncoder as ContextualBiasEncoder_export

from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export

def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
        return [BiCifParaformer_export(model, **export_config)]
    elif isinstance(model, NeatContextualParaformer):
        print('export NeatContextualParaformer ...')
        return [ContextualParaformer_export(model, **export_config), ContextualBiasEncoder_export(model, **export_config)]
    elif isinstance(model, Paraformer):
        print('export Paraformer ...')
        return [Paraformer_export(model, **export_config)]
    elif isinstance(model, E2EVadModel):
        print('export E2EVadModel ...')
        return [E2EVadModel_export(model, **export_config)]
    elif isinstance(model, PunctuationModel):
        print('export PunctuationModel ...')
        if isinstance(model.punc_model, TargetDelayTransformer):
            print('export TargetDelayTransformer ...')
            return [CT_Transformer_export(model.punc_model, **export_config)]
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            print('export VadRealtimeTransformer ...')
            return [CT_Transformer_VadRealtime_export(model.punc_model, **export_config)]

3)修改 funasr/export/models/e2e_asr_paraformer.py 脚本,脚本中增加NeatContextualParaformer 和 ContextualBiasEncoder 两个类来导出Paraformer主体模型和LSTM热词模型。

class NeatContextualParaformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            logging.warning("Unsupported encoder type to export.")

        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)

        # self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        # if hasattr(model, 'bias_encoder'):   # if isinstance(model.bias_encoder, torch.nn.LSTM):
            # logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
            # self.bias_encoder = model.bias_encoder
        # else:
            # logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualParaformerDecoder_export(model.decoder, onnx=onnx)
        else:
            logging.warning("Unsupported decoder type to export.")

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            hotword_embeds: torch.Tensor,
            clas_scale: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        # -1. bias encoder
        contextual_info = hotword_embeds.unsqueeze(0).repeat(pre_acoustic_embeds.shape[0], 1, 1)
        # print('contextual_info:', contextual_info.shape)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, contextual_info=contextual_info, clas_scale=clas_scale)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        hotword_embeds = torch.randn(30, self.context_dim)
        clas_scale = torch.tensor([1.0], dtype=torch.float32)
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt", hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        hotword = np.loadtxt(hot_file)
        hotword_embeds = torch.from_numpy(hotword.astype(np.float32))
        clas_scale = np.array([hotword[:,-1], ], dtype=np.float32)
        clas_scale = torch.from_numpy(clas_scale.astype(np.float32))
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_input_names(self):
        return ['speech', 'speech_lengths', 'hotword_embeds', 'clas_scale']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'hotword_embeds': {
                0: 'hotword_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }

class ContextualBiasEncoder(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        if hasattr(model, 'bias_encoder'):   # if isinstance(model.bias_encoder, torch.nn.LSTM):
            logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
            self.bias_encoder = model.bias_encoder
        else:
            logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name + '_contextual_bias_encoder'

    def forward(
            self,
            hotword: torch.Tensor,
            # h0: torch.Tensor,
            # c0: torch.Tensor,
    ):
        # -1. bias encoder
        hotword = hotword.unsqueeze(0)
        hw_embed = self.bias_embed(hotword)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        print('hw_embed:', hw_embed.shape)
        hotword_embed = hw_embed[0,-1,:].squeeze(0)
        print('hotword_embed:', hotword_embed.shape)

        return hotword_embed

    def get_dummy_inputs(self):
        hotword = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).type(torch.int32)
        #hotword = torch.Tensor([1]).type(torch.int32)
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        #return (hotword, h0, c0)
        return hotword

    def get_dummy_inputs_txt(self, hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        hotword = np.loadtxt(hot_file)
        hotword = torch.from_numpy(hotword.astype(np.int32))
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # return (hotword, h0, c0)
        return hotword

    def get_input_names(self):
        #return ['hotword', 'h0', 'c0']
        return ['hotword']

    def get_output_names(self):
        #return ['hotword_embed', 'hn', 'cn']
        return ['hotword_embed']

    def get_dynamic_axes(self):
        return {
            'hotword': {
                0: 'hotword_length',
            },
        }

4)新增funasr/export/models/modules/contextual_decoder_layer.py 脚本。脚本中两个类:ContextualDecoderLayer和 ContextualBiasDecoder 用于解码端热词所需的Attention计算。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn
from typing import Tuple

class ContextualDecoderLayer(nn.Module):

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):

        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
            x_self_attn = x

        residual = x
        if self.src_attn is not None:
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)
            x_src_attn = x

        x = residual + x

        return x, tgt_mask, x_self_attn, x_src_attn

class ContextualBiasDecoder(nn.Module):

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.src_attn = model.src_attn
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):

        x = tgt
        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache

5)新增 funasr/export/models/decoder/contextual_decoder.py脚本,用于导出解码器。

import os

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask

from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export

from funasr.models.decoder.contextual_decoder import ContextualDecoderLayer as ContextualDecoderLayer
from funasr.models.decoder.contextual_decoder import ContextualBiasDecoder as ContextualBiasDecoder
from funasr.export.models.modules.contextual_decoder_layer import ContextualDecoderLayer as ContextualDecoderLayer_export
from funasr.export.models.modules.contextual_decoder_layer import ContextualBiasDecoder as ContextualBiasDecoder_export

class ContextualParaformerSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        if self.model.last_decoder is not None:
            if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
                print('export model.last_decoder, feed_forward ...')
                self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
            if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
                print('export model.last_decoder, self_attn ...')
                self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
            if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.last_decoder, src_attn ...')
                self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
            self.model.last_decoder = ContextualDecoderLayer_export(self.model.last_decoder)

        if self.model.bias_decoder is not None:
            if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.bias_decoder, src_attn ...')
                self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
            self.model.bias_decoder = ContextualBiasDecoder_export(self.model.bias_decoder)

        self.model.bias_output = self.model.bias_output if hasattr(model, 'bias_output') else None

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
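        # mask_3d_btd: (B, T, 1) multiplicative mask (1 at valid frames, 0 at padding);
        # mask_4d_bhlt: additive attention bias, 0 at valid positions and -10000.0 at padding.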
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        clas_scale: float = 1.0
    ):

        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _  = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask  = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        #print('tgt:',tgt.shape, 'tgt_mask:',tgt_mask.shape, 'memory:',memory.shape, 'memory_mask:',memory_mask.shape)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        #print('\n>>> export last_decoder ....')
        _, _, x_self_attn, x_src_attn = self.model.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        #print('x_self_attn:',x_self_attn.shape, 'x_src_attn:',x_src_attn.shape)

        # contextual paraformer related
        # repeat the hotword count over the batch dimension to build the contextual mask
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = self.make_pad_mask(contextual_length)
        _, contextual_mask = self.prepare_mask(contextual_mask)
        #contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.model.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.model.bias_output is not None:
            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
            x = self.model.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + x

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        if hasattr(self.model, 'last_decoder'):
            cache.append(torch.zeros((1, self.model.last_decoder.size, self.model.last_decoder.self_attn.kernel_size)))
        if hasattr(self.model, 'bias_decoder'):
            cache.append(torch.zeros((1, self.model.bias_decoder.size, self.model.bias_decoder.src_attn.kernel_size)))
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']   
        return return_list

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']
        return return_list

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        ret.update({
            'contextual_info': {
                0: 'contextual_info_batch',
                1: 'contextual_info_length'
            }
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2) + len(self.model.last_decoder),
            "odim": self.model.decoders[0].size
        }
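
For reference, below is a minimal, untested driver sketch of how a wrapper like ContextualParaformerSANMDecoder could be traced with torch.onnx.export. The dummy shapes, input/output names and opset version are illustrative assumptions, not the exact FunASR export logic:

import torch

def export_contextual_decoder_onnx(dec, enc_size=512, out_path="decoder.onnx"):
    # Dummy inputs matching ContextualParaformerSANMDecoder.forward; all shapes are assumptions:
    # 1 utterance, 100 encoder frames, 30 predicted tokens, 10 preloaded hotword embeddings.
    hs_pad = torch.randn(1, 100, enc_size)
    hlens = torch.tensor([100], dtype=torch.int32)
    ys_in_pad = torch.randn(1, 30, enc_size)           # predicted acoustic embeddings, not token ids
    ys_in_lens = torch.tensor([30], dtype=torch.int32)
    contextual_info = torch.randn(1, 10, enc_size)     # hotword embeddings from the bias encoder
    clas_scale = torch.tensor([1.0], dtype=torch.float32)

    torch.onnx.export(
        dec,
        (hs_pad, hlens, ys_in_pad, ys_in_lens, contextual_info, clas_scale),
        out_path,
        input_names=["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len",
                     "bias_embed", "clas_scale"],
        output_names=["logits", "token_num"],
        dynamic_axes={
            "enc": {0: "batch", 1: "enc_len"},
            "acoustic_embeds": {0: "batch", 1: "token_len"},
            "bias_embed": {0: "batch", 1: "num_hotwords"},
        },
        opset_version=13,
    )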

6) Add funasr/runtime/python/onnxruntime/funasr_onnx/contextual_paraformer_bin.py for hotword inference (a short usage sketch follows the script).

# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple

import copy
import librosa
import numpy as np
import re

from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx

logging = get_logger()

class Contextual_Paraformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):

        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name in modelscope or a local path downloaded from modelscope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            raise FileNotFoundError("model_file {} does not exist, please export the onnx model first".format(model_file))

        context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder.onnx')
        if quantize:
            context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder_quant.onnx')
        if not os.path.exists(context_model_file):
            raise FileNotFoundError("context_model_file {} does not exist, please export the onnx model first".format(context_model_file))

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

        # contextual
        hotword_file = os.path.join(model_dir, 'hotword.txt')
        with open(hotword_file, 'r') as f:
            hotword_list = f.readlines()
        #print('context_model_file:', context_model_file)
        #print('hotword_list:', hotword_list)
        self.context_ort_infer = OrtInferSession(context_model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.hotword_embeds = self.extract_hotword_embed(hotword_list)
        print('self.hotword_embeds:', len(self.hotword_embeds), self.hotword_embeds[0].shape)

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], 
                       hotwords: List[str], clas_scale: np.ndarray, **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        hotword_embeds = self.load_hotword(hotwords)
        if not isinstance(clas_scale, np.ndarray):
            clas_scale = np.array([clas_scale], dtype=np.float32)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):

            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            print('feats:', feats.shape)
            print('feats_len:', feats_len)
            print('hotword_embeds:', hotword_embeds.shape)
            print('clas_scale:', clas_scale)
            try:
                outputs = self.infer(feats, feats_len, hotword_embeds, clas_scale)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                print('am_scores:', am_scores.shape)
                if len(outputs) == 4:
                    # for BiCifParaformer Inference
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                #logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
                        text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
                        # logging.warning(timestamp)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
                        asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
        return asr_res

    def plot_wave_timestamp(self, wav, text_timestamp, dest):
        # TODO: Plot the wav and timestamp results with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set it to a font that your system supports
        import matplotlib.pyplot as plt
        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4)
        # plot lines and text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        # plt.legend()
        plotname = "{}/timestamp.png".format(dest)
        plt.savefig(plotname, bbox_inches='tight')

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def load_hotword(self, hotwords: List[str]) -> np.ndarray:
        # copy the preloaded embeddings so repeated calls do not keep appending to them
        hotwords_embeds_list = list(self.hotword_embeds)
        if len(hotwords) > 0:
            new_hotwords = self.extract_hotword_embed(hotwords)
            hotwords_embeds_list.extend(new_hotwords)
        hotwords_embeds = np.stack(hotwords_embeds_list, axis=0)
        return hotwords_embeds

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def extract_hotword_embed(self, hotword_list):
        tokenids_list, hw_embed_list = [], []
        for hotword in hotword_list:
            hotword = hotword.strip()
            tokenids = self.converter.tokens2ids(hotword)
            tokenids_list.append(tokenids)
            tokenids = np.array(tokenids, dtype=np.int32)
            hw_embed = self.context_ort_infer([tokenids])[0]
            hw_embed_list.append(hw_embed)
        return hw_embed_list

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray,
              hotword_embed: np.ndarray, 
              clas_scale: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, hotword_embed, clas_scale])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
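
For completeness, a minimal usage sketch of the class above, assuming the funasr_onnx package from funasr/runtime/python/onnxruntime is installed. The model directory, wav path and hotword list below are placeholders; the directory is assumed to already contain model.onnx, model_contextual_bias_encoder.onnx, config.yaml, am.mvn and hotword.txt:

import numpy as np
from funasr_onnx.contextual_paraformer_bin import Contextual_Paraformer

model = Contextual_Paraformer("./export/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404")
result = model("./example/asr_example.wav",
               hotwords=["魔搭"],                      # extra hotwords on top of hotword.txt
               clas_scale=np.array([1.0], dtype=np.float32))
print(result)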

We greatly appreciate your effort. Your code is excellent and completely accurate. FunASR now officially supports export and inference of the Paraformer hotword model in ONNX format. We welcome you to try it out and give us feedback. The Python version is funasr-onnx 0.2.1, and the C++ version is funasr-runtime.

R1ckShi commented 1 year ago

Thanks again for your contribution. It looks like we did the same work at the same time; next time, feel free to raise a pull request directly if you have any updates.