A surprise is coming?!
For local modifications to the FunASR source code to take effect, FunASR needs to be installed from source:
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
# For the users in China, you could install with the command:
# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
1. Hotword model training
First download the damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch pretrained package locally, load the pretrained model and the related configuration files via modelscope_args, and then fine-tune the model with the following code:
import os

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset


def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args
    params = modelscope_args(model="./speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", data_path="./data")
    params.output_dir = "./checkpoint"       # path where checkpoints are saved
    params.data_path = "./dataset/aishell1"  # data path: a dataset already uploaded to ModelScope, or local data
    params.dataset_type = "small"            # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
    params.batch_bins = 2000                 # batch size; with dataset_type="small" the unit of batch_bins is fbank feature frames, with dataset_type="large" it is milliseconds
    params.max_epoch = 50                    # maximum number of training epochs
    params.lr = 0.00005                      # learning rate
    modelscope_finetune(params)
The training data is prepared as follows: a text file with transcripts and a wav.scp file with audio paths; Chinese and English can be mixed:
cat ./example_data/text
BAC009S0002W0122 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购
BAC009S0002W0123 也 成 为 地 方 政 府 的 眼 中 钉
english_example_1 hello world
english_example_2 go swim 去 游 泳

cat ./example_data/wav.scp
BAC009S0002W0122 /mnt/data/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /mnt/data/wav/train/S0002/BAC009S0002W0123.wav
english_example_1 /mnt/data/wav/train/S0002/english_example_1.wav
english_example_2 /mnt/data/wav/train/S0002/english_example_2.wav
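If it helps, here is a minimal, hypothetical helper (not part of FunASR; the function name and directory layout are my own) showing how text and wav.scp files in exactly this format can be generated from a list of utterances:

# Hypothetical helper (not from FunASR): writes Kaldi-style "text" and "wav.scp"
# files in the format shown above from (utt_id, wav_path, transcript) tuples.
import os

def write_asr_manifest(utts, out_dir="./example_data"):
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "text"), "w", encoding="utf-8") as f_text, \
         open(os.path.join(out_dir, "wav.scp"), "w", encoding="utf-8") as f_scp:
        for utt_id, wav_path, transcript in utts:
            f_text.write(f"{utt_id} {transcript}\n")  # utt_id followed by space-separated tokens
            f_scp.write(f"{utt_id} {wav_path}\n")     # utt_id followed by the wav path

# Example:
# write_asr_manifest([("BAC009S0002W0122",
#                      "/mnt/data/wav/train/S0002/BAC009S0002W0122.wav",
#                      "而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购")])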
2. Exporting the hotword model to ONNX
The current version lacks the code needed to export the hotword model to ONNX, so code has to be added in the corresponding places, as described below.

Note: at the moment, hotwords are sampled dynamically from the training text during training. The training code uses the ContextualParaformer class in funasr/models/e2e_asr_paraformer.py, whereas ONNX export actually uses the NeatContextualParaformer class in funasr/models/e2e_asr_contextual_paraformer.py. A few scripts therefore need to be added to export the ONNX model correctly.

The idea behind the hotword export is to split the model into two parts: the main Paraformer model and the hotword-embedding extractor. This has two benefits: 1) hotword embeddings do not have to be re-extracted on every call; all hotwords can be embedded once at initialization, and any new hotwords supplied at inference time only require an incremental computation, which reduces cost while keeping hotword handling flexible; 2) the hotword-embedding extractor is an LSTM whose batch_size can only be 1, which cannot match the batch_size of the main Paraformer model, so it has to be separated into its own model.
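To make the two-model split concrete, here is a minimal sketch of the intended inference flow with onnxruntime. It assumes the two exported files are named model.onnx and model_contextual_bias_encoder.onnx (as in the inference script of step 6 below), that the hotwords are already converted to token ids, and that the input/output names follow the get_input_names/get_output_names methods added below; feature extraction and decoding are omitted.

# Sketch of the two-model flow under the assumptions stated above.
import numpy as np
import onnxruntime as ort

bias_sess = ort.InferenceSession("model_contextual_bias_encoder.onnx",
                                 providers=["CPUExecutionProvider"])
asr_sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# 1) Embed each hotword once (the LSTM bias encoder only supports batch_size=1).
hotword_token_ids = [np.array([10, 11], dtype=np.int32),      # token ids of hotword 1
                     np.array([12, 13, 14], dtype=np.int32)]  # token ids of hotword 2
hotword_embeds = np.stack(
    [bias_sess.run(["hotword_embed"], {"hotword": ids})[0] for ids in hotword_token_ids],
    axis=0)

# 2) Run the main Paraformer model with the cached hotword embeddings.
feats = np.random.randn(1, 100, 560).astype(np.float32)  # placeholder fbank+LFR features
feats_len = np.array([100], dtype=np.int32)
clas_scale = np.array([1.0], dtype=np.float32)
logits, token_num = asr_sess.run(
    ["logits", "token_num"],
    {"speech": feats, "speech_lengths": feats_len,
     "hotword_embeds": hotword_embeds, "clas_scale": clas_scale})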
The changes are mainly to files under the funasr/export folder, specifically:
1) Modify the export_model.py script. Since the hotword version exports two ONNX models, the get_model function is changed to return a list, and the caller handles each returned model:
self.export_config["model_name"] = "model"
models = get_model(
    model,
    self.export_config,
)
print('models:', models)
print('models type:', type(models))
if len(models) > 1:
    asr_model, bias_model = models[0], models[1]

    asr_model.eval()
    # self._export_onnx(model, verbose, export_dir)
    if self.onnx:
        self._export_onnx(asr_model, verbose, export_dir)
    else:
        self._export_torchscripts(asr_model, verbose, export_dir)
    print("output {} dir: {}".format('asr_model', export_dir))

    bias_model.eval()
    # self._export_onnx(model, verbose, export_dir)
    if self.onnx:
        self._export_onnx(bias_model, verbose, export_dir)
    else:
        self._export_torchscripts(bias_model, verbose, export_dir)
    print("output {} dir: {}".format('bias_model', export_dir))
else:
    models = models[0]
    models.eval()
    # self._export_onnx(model, verbose, export_dir)
    if self.onnx:
        self._export_onnx(models, verbose, export_dir)
    else:
        self._export_torchscripts(models, verbose, export_dir)
    print("output dir: {}".format(export_dir))
2) Modify the get_model function in funasr/export/models/__init__.py. It dispatches to the NeatContextualParaformer and ContextualBiasEncoder export classes that are added to funasr/export/models/e2e_asr_paraformer.py in step 3, which export the main Paraformer model and the LSTM hotword model respectively.
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import NeatContextualParaformer as ContextualParaformer_export
from funasr.export.models.e2e_asr_paraformer import ContextualBiasEncoder as ContextualBiasEncoder_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export


def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
        return [BiCifParaformer_export(model, **export_config)]
    elif isinstance(model, NeatContextualParaformer):
        print('export NeatContextualParaformer ...')
        return [ContextualParaformer_export(model, **export_config),
                ContextualBiasEncoder_export(model, **export_config)]
    elif isinstance(model, Paraformer):
        print('export Paraformer ...')
        return [Paraformer_export(model, **export_config)]
    elif isinstance(model, E2EVadModel):
        print('export E2EVadModel ...')
        return [E2EVadModel_export(model, **export_config)]
    elif isinstance(model, PunctuationModel):
        print('export PunctuationModel ...')
        if isinstance(model.punc_model, TargetDelayTransformer):
            print('export TargetDelayTransformer ...')
            return [CT_Transformer_export(model.punc_model, **export_config)]
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            print('export VadRealtimeTransformer ...')
            return [CT_Transformer_VadRealtime_export(model.punc_model, **export_config)]
3) Modify the funasr/export/models/e2e_asr_paraformer.py script, adding the NeatContextualParaformer and ContextualBiasEncoder classes, which export the main Paraformer model and the LSTM hotword model (a rough export sketch follows the code below):
class NeatContextualParaformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            logging.warning("Unsupported encoder type to export.")
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)

        # self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        # if hasattr(model, 'bias_encoder'):
        #     if isinstance(model.bias_encoder, torch.nn.LSTM):
        #         logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
        #         self.bias_encoder = model.bias_encoder
        #     else:
        #         logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualParaformerDecoder_export(model.decoder, onnx=onnx)
        else:
            logging.warning("Unsupported decoder type to export.")

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            hotword_embeds: torch.Tensor,
            clas_scale: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        # -1. bias encoder
        contextual_info = hotword_embeds.unsqueeze(0).repeat(pre_acoustic_embeds.shape[0], 1, 1)
        # print('contextual_info:', contextual_info.shape)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length,
                                      contextual_info=contextual_info, clas_scale=clas_scale)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        hotword_embeds = torch.randn(30, self.context_dim)
        clas_scale = torch.tensor([1.0], dtype=torch.float32)
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt",
                             hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        hotword = np.loadtxt(hot_file)
        hotword_embeds = torch.from_numpy(hotword.astype(np.float32))
        clas_scale = np.array([hotword[:, -1], ], dtype=np.float32)
        clas_scale = torch.from_numpy(clas_scale.astype(np.float32))
        return (speech, speech_lengths, hotword_embeds, clas_scale)

    def get_input_names(self):
        return ['speech', 'speech_lengths', 'hotword_embeds', 'clas_scale']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'hotword_embeds': {
                0: 'hotword_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }


class ContextualBiasEncoder(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            context_dim=512,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        self.bias_embed = model.bias_embed if hasattr(model, 'bias_embed') else None
        if hasattr(model, 'bias_encoder'):
            # if isinstance(model.bias_encoder, torch.nn.LSTM):
            logging.warning("enable bias encoder sampling and contextual training with bias_encoder")
            self.bias_encoder = model.bias_encoder
        else:
            logging.warning("Unsupport bias encoder type: {}".format(model.bias_encoder))

        self.feats_dim = feats_dim
        self.context_dim = context_dim
        self.model_name = model_name + '_contextual_bias_encoder'

    def forward(
            self,
            hotword: torch.Tensor,
            # h0: torch.Tensor,
            # c0: torch.Tensor,
    ):
        # -1. bias encoder
        hotword = hotword.unsqueeze(0)
        hw_embed = self.bias_embed(hotword)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        print('hw_embed:', hw_embed.shape)
        hotword_embed = hw_embed[0, -1, :].squeeze(0)
        print('hotword_embed:', hotword_embed.shape)
        return hotword_embed

    def get_dummy_inputs(self):
        hotword = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).type(torch.int32)
        # hotword = torch.Tensor([1]).type(torch.int32)
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # return (hotword, h0, c0)
        return hotword

    def get_dummy_inputs_txt(self, hot_file: str = "/mnt/workspace/hotword.txt"):
        import numpy as np
        hotword = np.loadtxt(hot_file)
        hotword = torch.from_numpy(hotword.astype(np.int32))
        # h0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # c0 = torch.zeros([1, 1, self.context_dim], dtype=torch.float32)
        # return (hotword, h0, c0)
        return hotword

    def get_input_names(self):
        # return ['hotword', 'h0', 'c0']
        return ['hotword']

    def get_output_names(self):
        # return ['hotword_embed', 'hn', 'cn']
        return ['hotword_embed']

    def get_dynamic_axes(self):
        return {
            'hotword': {
                0: 'hotword_length',
            },
        }
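For reference, the two wrapper classes above are designed to be driven by FunASR's generic export loop through their get_dummy_inputs / get_input_names / get_output_names / get_dynamic_axes helpers. The sketch below shows roughly what that looks like; the function name, opset version, and output path are my own assumptions, not the exact FunASR implementation.

# Rough sketch of driving one export wrapper with torch.onnx.export; the helper
# name, opset version and output path here are assumptions, not FunASR's code.
import os
import torch

def export_wrapper_to_onnx(wrapper: torch.nn.Module, export_dir: str, verbose: bool = False):
    wrapper.eval()
    dummy_inputs = wrapper.get_dummy_inputs()
    onnx_path = os.path.join(export_dir, f"{wrapper.model_name}.onnx")
    torch.onnx.export(
        wrapper,         # NeatContextualParaformer or ContextualBiasEncoder wrapper
        dummy_inputs,    # tuple (or single tensor) from get_dummy_inputs()
        onnx_path,
        verbose=verbose,
        opset_version=14,
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
    )
    return onnx_path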
4) Add the script funasr/export/models/modules/contextual_decoder_layer.py. Its two classes, ContextualDecoderLayer and ContextualBiasDecoder, implement the attention computation needed on the decoder side for hotwords.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn
from typing import Tuple


class ContextualDecoderLayer(nn.Module):
    def __init__(
            self,
            model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x

        x_self_attn = x
        residual = x
        if self.src_attn is not None:
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + x
        return x, tgt_mask, x_self_attn, x_src_attn


class ContextualBiasDecoder(nn.Module):
    def __init__(
            self,
            model
    ):
        super().__init__()
        self.src_attn = model.src_attn
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        x = tgt
        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)
        return x, tgt_mask, memory, memory_mask, cache
5) Add the script funasr/export/models/decoder/contextual_decoder.py, which exports the decoder.
import os

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
from funasr.models.decoder.contextual_decoder import ContextualDecoderLayer as ContextualDecoderLayer
from funasr.models.decoder.contextual_decoder import ContextualBiasDecoder as ContextualBiasDecoder
from funasr.export.models.modules.contextual_decoder_layer import ContextualDecoderLayer as ContextualDecoderLayer_export
from funasr.export.models.modules.contextual_decoder_layer import ContextualBiasDecoder as ContextualBiasDecoder_export


class ContextualParaformerSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        if self.model.last_decoder is not None:
            if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
                print('export model.last_decoder, feed_forward ...')
                self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
            if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
                print('export model.last_decoder, self_attn ...')
                self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
            if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.last_decoder, src_attn ...')
                self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
            self.model.last_decoder = ContextualDecoderLayer_export(self.model.last_decoder)

        if self.model.bias_decoder is not None:
            if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
                print('export model.bias_decoder, src_attn ...')
                self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
            self.model.bias_decoder = ContextualBiasDecoder_export(self.model.bias_decoder)

        self.model.bias_output = self.model.bias_output if hasattr(model, 'bias_output') else None

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(
            self,
            hs_pad: torch.Tensor,
            hlens: torch.Tensor,
            ys_in_pad: torch.Tensor,
            ys_in_lens: torch.Tensor,
            contextual_info: torch.Tensor,
            clas_scale: float = 1.0
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        # print('tgt:', tgt.shape, 'tgt_mask:', tgt_mask.shape, 'memory:', memory.shape, 'memory_mask:', memory_mask.shape)
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        # print('\n>>> export last_decoder ....')
        _, _, x_self_attn, x_src_attn = self.model.last_decoder(
            x, tgt_mask, memory, memory_mask
        )
        # print('x_self_attn:', x_self_attn.shape, 'x_src_attn:', x_src_attn.shape)

        # contextual paraformer related
        # contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_length = contextual_info.shape[1].repeat(hs_pad.shape[0])
        contextual_mask = self.make_pad_mask(contextual_length)
        _, contextual_mask = self.prepare_mask(contextual_mask)
        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.model.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.model.bias_output is not None:
            x = torch.cat([x_src_attn, cx * clas_scale], dim=2)
            x = self.model.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + x

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        if hasattr(self.model, 'last_decoder'):
            cache.append(torch.zeros((1, self.model.last_decoder.size, self.model.last_decoder.self_attn.kernel_size)))
        if hasattr(self.model, 'bias_decoder'):
            cache.append(torch.zeros((1, self.model.bias_decoder.size, self.model.bias_decoder.src_attn.kernel_size)))
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']
        return return_list

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return_list = ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]
        if hasattr(self.model, 'last_decoder'):
            return_list = return_list + ['cache_last']
        if hasattr(self.model, 'bias_decoder'):
            return_list = return_list + ['cache_bias']
        return return_list

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        ret.update({
            'contextual_info': {
                0: 'contextual_info_batch',
                1: 'contextual_info_length'
            }
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2) + len(self.model.last_decoder),
            "odim": self.model.decoders[0].size
        }
6) Add funasr/runtime/python/onnxruntime/funasr_onnx/contextual_paraformer_bin.py for hotword inference (a usage sketch follows the code below):
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy

import librosa
import numpy as np
import re

from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx

logging = get_logger()


class Contextual_Paraformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):

        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError("model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            raise FileNotFoundError("model_file {} does not exist, please export the onnx model first".format(model_file))

        context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder.onnx')
        if quantize:
            context_model_file = os.path.join(model_dir, 'model_contextual_bias_encoder_quant.onnx')
        if not os.path.exists(context_model_file):
            raise FileNotFoundError("model_file {} does not exist, please export the onnx model first".format(context_model_file))

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

        # contextual
        hotword_file = os.path.join(model_dir, 'hotword.txt')
        with open(hotword_file, 'r') as f:
            hotword_list = f.readlines()
        # print('context_model_file:', context_model_file)
        # print('hotword_list:', hotword_list)
        self.context_ort_infer = OrtInferSession(context_model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.hotword_embeds = self.extract_hotword_embed(hotword_list)
        print('self.hotword_embeds:', len(self.hotword_embeds), self.hotword_embeds[0].shape)

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: List[str],
                 clas_scale: np.ndarray,
                 **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)

        hotword_embeds = self.load_hotword(hotwords)
        clas_scale_r = np.array([1.0], dtype=np.float32)
        if not isinstance(clas_scale, type(clas_scale_r)):
            clas_scale = np.array([clas_scale], dtype=np.float32)
        print('clas_scale:', type(clas_scale))
        print('hotword_embeds:', hotword_embeds.shape)

        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            print('feats:', feats.shape)
            print('feats_len:', feats_len)
            print('hotword_embeds:', hotword_embeds.shape)
            print('clas_scale:', clas_scale)
            try:
                outputs = self.infer(feats, feats_len, hotword_embeds, clas_scale)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                print('am_scores:', am_scores.shape)
                if len(outputs) == 4:
                    # for BiCifParaformer Inference
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
                        text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
                        # logging.warning(timestamp)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
                        asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
        return asr_res

    def plot_wave_timestamp(self, wav, text_timestamp, dest):
        # TODO: Plot the wav and timestamp results with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set it to a font that your system supports
        import matplotlib.pyplot as plt
        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav / wav.max() * 0.3, color='gray', alpha=0.4)
        # plot lines and text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        # plt.legend()
        plotname = "{}/timestamp.png".format(dest)
        plt.savefig(plotname, bbox_inches='tight')

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]],
                  fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def load_hotword(self, hotwords: List[str]) -> np.ndarray:
        hotwords_embeds_list = self.hotword_embeds
        new_hotwords = []
        if len(hotwords) > 0:
            new_hotwords = self.extract_hotword_embed(hotwords)
            hotwords_embeds_list.extend(new_hotwords)
        hotwords_embeds = np.stack(hotwords_embeds_list, axis=0)
        return hotwords_embeds

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def extract_hotword_embed(self, hotword_list):
        tokenids_list, hw_embed_list = [], []
        for hotword in hotword_list:
            hotword = hotword.strip()
            tokenids = self.converter.tokens2ids(hotword)
            tokenids_list.append(tokenids)
            tokenids = np.array(tokenids, dtype=np.int32)
            hw_embed = self.context_ort_infer([tokenids])[0]
            hw_embed_list.append(hw_embed)
        return hw_embed_list

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray,
              hotword_embed: np.ndarray,
              clas_scale: np.ndarray,
              ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, hotword_embed, clas_scale])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num - self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
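Finally, a brief usage sketch of the class above. It assumes the module is importable as funasr_onnx.contextual_paraformer_bin (matching the file location), that the model directory contains model.onnx, model_contextual_bias_encoder.onnx, hotword.txt, config.yaml and am.mvn as the constructor expects, and the paths and hotword are placeholders.

# Usage sketch for the Contextual_Paraformer class defined above; paths are placeholders
# and the import path is an assumption based on where the file is added.
import numpy as np
from funasr_onnx.contextual_paraformer_bin import Contextual_Paraformer

model = Contextual_Paraformer(
    model_dir="./export/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
    batch_size=1,
    quantize=False,
)

# Extra hotwords can be passed per call on top of those listed in hotword.txt;
# clas_scale controls how strongly the hotword branch biases the decoding.
result = model("./example_data/BAC009S0002W0122.wav",
               hotwords=["魔搭"],
               clas_scale=np.array([1.0], dtype=np.float32))
print(result)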
We greatly appreciate your effort. Your code is excellent and completely accurate. FunASR now officially supports export and inference of the Paraformer hotword model in ONNX format. We welcome you to try it out and give us feedback. The Python version is funasr-onnx 0.2.1 and the C++ version is funasr-runtime.
Thanks again for your contribution. It looks like we did the same work at the same time; next time, feel free to open a pull request directly if you have further updates.