modelscope / FunASR

A Fundamental End-to-End Speech Recognition Toolkit and Open Source SOTA Pretrained Models, Supporting Speech Recognition, Voice Activity Detection, Text Post-processing etc.
https://www.funasr.com

FunASR punctuation model finetune and ONNX model export #893

Open ROAD2018 opened 1 year ago

ROAD2018 commented 1 year ago

Note: this tutorial performs punctuation model fine-tuning and ONNX export entirely with FunASR; modelscope is not involved.

1. Punctuation model training

Punctuation model training and fine-tuning follows the FunASR/egs/aishell2 example; the specific steps are as follows:

1) Download the pretrained punctuation model folder punc_ct-transformer_zh-cn-common-vocab272727-pytorch into the local FunASR/egs/aishell2 directory.

2) Create a tokenize_text.py file under FunASR/egs/aishell2 for text and punctuation processing. It extracts characters and punctuation labels from the input text according to the punc.yaml config file in the pretrained punc_ct-transformer_zh-cn-common-vocab272727-pytorch folder. Text normalization can be done either with the WeTextProcessing toolkit (see the short sketch below) or with FunASR's own normalization scripts.
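For reference, WeTextProcessing normalization on its own is only a couple of lines; a minimal sketch (the cache_dir argument used in the script below is optional and only points at prebuilt FSTs):

from tn.chinese.normalizer import Normalizer

normalizer = Normalizer()
# Rewrites digits, dates, symbols, etc. into Chinese characters
print(normalizer.normalize("共有12345人参加了2023年的会议"))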

The tokenize_text.py script is as follows:

#!/usr/bin/env python3
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional, Union, Dict
import re, os
import yaml

from tn.chinese.normalizer import Normalizer
normalizer = Normalizer(cache_dir='WeTextProcessing/tn')

from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f'The {yaml_path} does not exist.')

    with open(str(yaml_path), 'rb') as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data

def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences

def normalize_punc(punc_list):
    # Map ASCII punctuation to its full-width equivalent and fold '!'/':' variants
    # into ',' or '。' so that only labels from the model's punc_list remain.
    puncs = []
    if len(punc_list) > 0:
        for punc in punc_list:
            if punc == ",":
                punc = ","
            elif punc == "?":
                punc = "?"
            elif punc == "!":
                punc = "。"
            elif punc == "!":
                punc = "。"
            elif punc == ":":
                punc = ","
            elif punc == ":":
                punc = ","
            elif punc == ".":
                punc = "。"
            puncs.append(punc)
    return puncs

def code_mix_split_words(text: str, punc_list: list):
    words = []
    puncs = []
    segs = text.split()
    for idx, seg in enumerate(segs):
        # There is no space in seg.
        current_word = ""
        for c in seg:
            if c in punc_list:
                if len(puncs) > 0:
                    puncs.pop() 
                puncs.append(c)
            elif len(c.encode()) == 1:
                # This is an ASCII char.
                current_word += c
            else:
                # This is a Chinese char.
                if len(current_word) > 0:
                    words.append(current_word)
                    puncs.append('_')
                    current_word = ""
                words.append(c)
                puncs.append('_')
        if len(current_word) > 0:
            words.append(current_word)
            puncs.append('_')
        if len(puncs) > 0 and idx < len(segs) - 1:
            puncs.pop() 
            puncs.append(',')
    if text[-1] not in punc_list:
        if len(puncs) < 5:
            if len(puncs) > 0:
                puncs.pop()
            puncs.append(',')
        else:
            if len(puncs) > 0:
                puncs.pop()
            puncs.append('。')

    return words, puncs

def is_chinese(char):
    if '\u4e00' <= char <= '\u9fff':
        return 1
    else:
        return 0

def is_english(char):
    if '\u0041' <= char <= '\u007a':
        return 1
    else:
        return 0

def is_number(char):
    if '\u0030' <= char <= '\u0039':
        return 1
    else:
        return 0

def alpha_lower(text):
    chars = [char.lower() if char.isalpha() else char for char in text]
    sent = ''.join(chars)
    return sent

def remove_special_symbol(text):
    symbols = ['(',')','【','】','[',']','(',')','<','>','《','》', '\n', '\t', '\r', '*', '“', '”', '"', '"', '~', '~','-', '-','_','+','*','×','&','@','#']
    sent = text
    for s in symbols:
        sent = sent.replace(s,'')
    return sent

def tokenize(
    input: str,
    output_dir: str,
    stats_dir: str,
    config_file: str,
    token_file: str,
    punc_file: str,
    token_type: str,
    data_type: str,
    delimiter: Optional[str],
    non_linguistic_symbols: Optional[str],
    log_level: str,
    remove_non_linguistic_symbols: bool,
    only_calc_shape: bool,
):

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")

    corpus_text_file = ''
    corpus_punc_file = ''
    if output_dir == "-":
        otout = sys.stdout
        opout = sys.stdout
    else:
        if data_type == 'train':
            corpus_text_file = output_dir + '/train/text'
            corpus_punc_file = output_dir + '/train/punc'
        elif data_type == 'valid':
            corpus_text_file = output_dir + '/valid/text'
            corpus_punc_file = output_dir + '/valid/punc'
        elif data_type == 'test':
            corpus_text_file = output_dir + '/test/text'
            corpus_punc_file = output_dir + '/test/punc'
        else:
            print('>>> !!! Not supported dataset type for corpus, should be one of train,valid,test !!!')
            exit()
        ot = Path(corpus_text_file)
        ot.parent.mkdir(parents=True, exist_ok=True)
        otout = ot.open("w", encoding="utf-8")
        op = Path(corpus_punc_file)
        op.parent.mkdir(parents=True, exist_ok=True)
        opout = op.open("w", encoding="utf-8")

    if not os.path.exists(token_file):
        if token_file == "-":
            tf = sys.stdout
        else:
            tf = Path(token_file)
            tf.parent.mkdir(parents=True, exist_ok=True)
            tfout = tf.open("w", encoding="utf-8")
    if not os.path.exists(punc_file):    
        if punc_file == "-":
            pf = sys.stdout
        else:
            pf = Path(punc_file)
            pf.parent.mkdir(parents=True, exist_ok=True)
            pfout = pf.open("w", encoding="utf-8")

    text_shape_file = ''
    punc_shape_file = ''
    if stats_dir == "-":
        stout = sys.stdout
        spout = sys.stdout
    else:
        if data_type == 'train':
            text_shape_file = stats_dir + '/train/text_shape'
            punc_shape_file = stats_dir + '/train/punc_shape'
        elif data_type == 'valid':
            text_shape_file = stats_dir + '/valid/text_shape'
            punc_shape_file = stats_dir + '/valid/punc_shape'
        elif data_type == 'test':
            text_shape_file = stats_dir + '/test/text_shape'
            punc_shape_file = stats_dir + '/test/punc_shape'
        else:
            print('>>> !!! Not supported dataset type for shape file, should be one of train,valid,test !!!')
            exit()
        st = Path(text_shape_file)
        st.parent.mkdir(parents=True, exist_ok=True)
        stout = st.open("w", encoding="utf-8")
        sp = Path(punc_shape_file)
        sp.parent.mkdir(parents=True, exist_ok=True)
        spout = sp.open("w", encoding="utf-8")

    unk_counter = Counter()
    all_counter = Counter()

    config = read_yaml(config_file)
    token_list = config['token_list']
    punc_list = config['punc_list']

    if not os.path.exists(token_file):
        for token in token_list:
            tfout.write(token + "\n")
        print('\n>>> token_list generated successfully !!!')
    else:
        print('\n>>> token_list already exists, bypass !!!')
    if not os.path.exists(punc_file):
        for punc in punc_list:
            pfout.write(punc + "\n")
        print('\n>>> punc_list generated successfully !!!')
    else:
        print('\n>>> punc_list already exists, bypass !!!')

    if only_calc_shape:
        for idx, line in enumerate(fin):
            line = line.rstrip()
            name, text = line.split(' ', maxsplit=1)
            tokens = text.split(' ')
            shape_strings = name + ' ' + str(len(tokens))
            opout.write(shape_strings + "\n")
            print(shape_strings)
        print('\n!!! Only shape file generated !!!\n')
        os.system('cp {} {}'.format(input, output_dir))
        print('\n!!! File {} copied to {} !!!\n'.format(input, output_dir))
        return

    ### get tokens
    punc_list_ext = ["unk", "?", ",", ".", ",", "?", "。", ":", ":"]
    for idx, line in enumerate(fin):
        idx = str('{:020d}'.format(idx))
        line = line.rstrip()
        line = remove_special_symbol(line)
        print(line)
        ##line = normalizer.normalize(line)

        words, puncs = code_mix_split_words(line, punc_list_ext)
        puncs = normalize_punc(puncs)

        words_strings = idx + ' ' + delimiter.join(words)
        words_shape_strings = idx + ' ' + str(len(words))

        puncs_strings = idx + ' ' + delimiter.join(puncs)
        puncs_shape_strings = idx + ' ' + str(len(puncs))

        print(words_strings + '\n' + puncs_strings)
        print(words_shape_strings + '\n' + puncs_shape_strings)

        otout.write(words_strings + "\n")        # corpus text file
        opout.write(puncs_strings + "\n")        # corpus punc file

        stout.write(words_shape_strings + "\n")  # text shape file
        spout.write(puncs_shape_strings + "\n")  # punc shape file

def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument(
        "--data_type",
        "-d",
        default="train",
        choices=["train", "valid", "test"],
        help="dataset type",
    )
    parser.add_argument(
        "--config_file", "-c", required=True, help="Path to config file"
    )
    parser.add_argument(
        "--token_file", "-k", required=True, help="Path to token file"
    )
    parser.add_argument(
        "--punc_file", "-p", required=True, help="Path to punc file"
    )

    parser.add_argument(
            "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
        )
    parser.add_argument(
            "--output_dir", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
            "--stats_dir", "-s", required=True, help="lm stats dir"
    )
    parser.add_argument("--delimiter", "-l", default=' ', type=str, help="The delimiter")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--only_calc_shape",
        type=str2bool,
        default=False,
        help="Whether to only calculate shape",
    )

    return parser

def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)

if __name__ == "__main__":
    main()

The text to be processed is domain-specific punctuated text; the text and punctuation labels before and after processing look like this:

Original text: 快过年了就囤一些放在家里,吃大鱼大肉多了,正好有它来去去腻。
text: 快 过 年 了 就 囤 一 些 放 在 家 里 吃 大 鱼 大 肉 多 了 正 好 有 它 来 去 去 腻
punc: _ _ _ _ _ _ _ _ _ _ _ , _ _ _ _ _ _ , _ _ _ _ _ _ _ 。
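A minimal sketch of how the helpers in tokenize_text.py produce this output, reusing the punc_list_ext defined inside tokenize():

punc_list_ext = ["unk", "?", ",", ".", ",", "?", "。", ":", ":"]
line = remove_special_symbol("快过年了就囤一些放在家里,吃大鱼大肉多了,正好有它来去去腻。")
words, puncs = code_mix_split_words(line, punc_list_ext)
puncs = normalize_punc(puncs)
print(' '.join(words))   # characters only, punctuation stripped
print(' '.join(puncs))   # one label per character: '_', ',' or '。'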

3) Modify run.sh to obtain the finetune.sh script (a usage example follows the script):

#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
count=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
lang=zh_en
nlsyms_txt=none            # Non-linguistic symbol list if existing.
cleaner=none               # Text cleaner.
g2p=none         # g2p method (needed if token_type=phn).
punc_fold_length=150         # fold_length for punctuation model training.
word_vocab_size=10000 # Size of word vocabulary.
token_type=word
delimiter=' '
split_with_space=true

token_list=./tokens.txt
punc_list=./puncs.txt

nj=10
## paths to the punctuated train/dev/test texts
train_text=./corpus/train.txt
dev_text=./corpus/valid.txt
test_text=./corpus/test.txt

corpus_output_dir=./dataset/scripts/content
shape_stats_dir=./dataset/scripts/shape

train_text_tokenized="${corpus_output_dir}"/train/text
train_punc_tokenized="${corpus_output_dir}"/train/punc

dev_text_tokenized="${corpus_output_dir}"/valid/text
dev_punc_tokenized="${corpus_output_dir}"/valid/punc

test_text_tokenized="${corpus_output_dir}"/test/text
test_punc_tokenized="${corpus_output_dir}"/test/punc

# train_data_path_and_name_and_type=${punc_train_text},text,text
# train_shape_file=
# valid_data_path_and_name_and_type=${punc_dev_text},text,text
# valid_shape_file=

text_train_data_path_and_name_and_type=${train_text_tokenized},text,text
text_train_shape_file=
punc_train_data_path_and_name_and_type=${train_punc_tokenized},punc,text
punc_train_shape_file=

text_valid_data_path_and_name_and_type=${dev_text_tokenized},text,text
text_valid_shape_file=
punc_valid_data_path_and_name_and_type=${dev_punc_tokenized},punc,text
punc_valid_shape_file=

punc_config=conf/train_punc.yaml
exp_dir=./data
tag=exp1
model_dir="baseline_$(basename "${punc_config}" .yaml)_${lang}_${token_type}_${tag}"
punc_exp=${exp_dir}/exp/${model_dir}

inference_punc=valid.loss.ave.pb       # Punctuation model checkpoint used for decoding.

stage=0
stop_stage=3

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

min() {
  local a b
  a=$1
  for b in "$@"; do
      if [ "${b}" -le "${a}" ]; then
          a="${b}"
      fi
  done
  echo "${a}"
}

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

mkdir -p ${exp_dir}/exp/${model_dir}
blank="<blank>" # CTC blank symbole
sos="<s>"       # sos symbole
eos="</s>"      # eos symbole
oov="<unk>"     # Out of vocabulary symbol.
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    if [ "${token_type}" = char ] || [ "${token_type}" = word ]; then
        echo "Stage 0: Generate character level token_list from ${train_text}"

        # The first symbol in token_list must be "<blank>":
        # 0 is reserved for CTC-blank for ASR and also used as ignore-index in the other task
        python tokenize_text.py  \
            --data_type "train" \
            --token_type "${token_type}" \
            --config_file "${punc_config}" \
            --token_file "${token_list}" \
            --punc_file "${punc_list}" \
            --delimiter "${delimiter}"  \
            --input "${train_text}" \
            --output_dir "${corpus_output_dir}" \
            --stats_dir "${punc_stats_dir}" \
            --non_linguistic_symbols "${nlsyms_txt}" \
            --only_calc_shape false

        python tokenize_text.py  \
            --data_type "valid" \
            --token_type "${token_type}" \
            --config_file "${punc_config}" \
            --token_file "${token_list}" \
            --punc_file "${punc_list}" \
            --delimiter "${delimiter}"  \
            --input "${dev_text}" \
            --output_dir "${corpus_output_dir}" \
            --stats_dir "${punc_stats_dir}" \
            --non_linguistic_symbols "${nlsyms_txt}" \
            --only_calc_shape false
    else
        echo "Error: not supported --token_type '${token_type}'"
        exit 2
    fi
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Data preparation"

    # 1. Split the key file
    _logdir="${exp_dir}/exp/${model_dir}/log"
    mkdir -p "${_logdir}"
    # Get the minimum number among ${nj} and the number lines of input files
    _nj=$(min "${nj}" "$(<${train_text_tokenized} wc -l)" "$(<${dev_text_tokenized} wc -l)")

    key_file="${train_text_tokenized}"
    split_scps=""
    for n in $(seq ${_nj}); do
        split_scps+=" ${_logdir}/train.text.${n}.scp"
    done
    # shellcheck disable=SC2086
    utils/split_scp.pl "${key_file}" ${split_scps}

    key_file="${train_punc_tokenized}"
    split_scps=""
    for n in $(seq ${_nj}); do
        split_scps+=" ${_logdir}/train.punc.${n}.scp"
    done
    # shellcheck disable=SC2086
    utils/split_scp.pl "${key_file}" ${split_scps}

    key_file="${dev_text_tokenized}"
    split_scps=""
    for n in $(seq ${_nj}); do
        split_scps+=" ${_logdir}/dev.text.${n}.scp"
    done
    # shellcheck disable=SC2086
    utils/split_scp.pl "${key_file}" ${split_scps}

    key_file="${dev_punc_tokenized}"
    split_scps=""
    for n in $(seq ${_nj}); do
        split_scps+=" ${_logdir}/dev.punc.${n}.scp"
    done
    # shellcheck disable=SC2086
    utils/split_scp.pl "${key_file}" ${split_scps}

    # 2. Submit jobs
    # Append the num-tokens at the last dimensions. This is used for batch-bins count
    <"${shape_stats_dir}/train/text_shape" \
        awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
        >"${shape_stats_dir}/train/text_shape.${token_type}"
    <"${shape_stats_dir}/train/punc_shape" \
        awk -v N="$(<${punc_list} wc -l)" '{ print $0 "," N }' \
        >"${shape_stats_dir}/train/punc_shape.${token_type}"

    <"${shape_stats_dir}/valid/text_shape" \
        awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
        >"${shape_stats_dir}/valid/text_shape.${token_type}"
    <"${shape_stats_dir}/valid/text_shape" \
        awk -v N="$(<${punc_list} wc -l)" '{ print $0 "," N }' \
        >"${shape_stats_dir}/valid/punc_shape.${token_type}"

    train_text_shape_file=${shape_stats_dir}/train/text_shape.${token_type}
    train_punc_shape_file=${shape_stats_dir}/train/punc_shape.${token_type}

    valid_text_shape_file=${shape_stats_dir}/valid/text_shape.${token_type}
    valid_punc_shape_file=${shape_stats_dir}/valid/punc_shape.${token_type}

fi

# Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Training"
    token_list="${token_list}"
    punc_list="${punc_list}"
    punc_stats_dir=${exp_dir}/exp/${model_dir}
    train_text_shape_file=${shape_stats_dir}/train/text_shape.${token_type}
    train_punc_shape_file=${shape_stats_dir}/train/punc_shape.${token_type}
    valid_text_shape_file=${shape_stats_dir}/valid/text_shape.${token_type}
    valid_punc_shape_file=${shape_stats_dir}/valid/punc_shape.${token_type}

    mkdir -p ${punc_exp}
    mkdir -p ${punc_exp}/log
    INIT_FILE=${punc_exp}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            python  ../../../funasr/bin/punc_train.py \
                --gpu_id ${gpu_id} \
                --use_preprocessor true \
                --token_type "${token_type}" \
                --token_list "${token_list}" \
                --punc_list "${punc_list}" \
                --non_linguistic_symbols "${nlsyms_txt}" \
                --cleaner "${cleaner}" \
                --split_with_space "${split_with_space}" \
                --train_data_path_and_name_and_type "${text_train_data_path_and_name_and_type}" \
                --train_data_path_and_name_and_type "${punc_train_data_path_and_name_and_type}" \
                --train_shape_file "${train_text_shape_file}" \
                --train_shape_file "${train_punc_shape_file}" \
                --valid_data_path_and_name_and_type "${text_valid_data_path_and_name_and_type}" \
                --valid_data_path_and_name_and_type "${punc_valid_data_path_and_name_and_type}" \
                --valid_shape_file "${valid_text_shape_file}" \
                --valid_shape_file "${valid_punc_shape_file}" \
                --fold_length "${punc_fold_length}" \
                --resume true \
                --output_dir "${punc_exp}" \
                --config ${punc_config} \
                --ngpu ${gpu_num} \
                --num_worker_count ${count} \
                --multiprocessing_distributed true \
                --dist_init_method ${init_method} \
                --dist_world_size ${world_size} \
                --dist_rank ${rank} \
                --local_rank ${local_rank}
        } & 
      done
      wait
fi

# Testing Stage
gpu_num=1
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "Stage 3: Calc perplexity: ${punc_test_text}"
    python tokenize_text.py  \
            --data_type "test" \
            --token_type "${token_type}" \
            --config_file "${punc_config}" \
            --token_file "${token_list}" \
            --punc_file "${punc_list}" \
            --delimiter "${delimiter}"  \
            --input "${test_text}" \
            --output_dir "${corpus_output_dir}" \
            --stats_dir "${punc_stats_dir}" \
            --non_linguistic_symbols "${nlsyms_txt}" \
            --only_calc_shape false

    python ../../../funasr/bin/inference_punc.py \
        --output_dir "${punc_exp}" \
        --ngpu 1 \
        --gpuid_list 0 \
        --batch_size 1 \
        --train_config "${punc_exp}"/config.yaml \
        --mode "transformer" \
        --model_file "${punc_exp}/${inference_punc}" \
        --data_path_and_name_and_type "${punc_test_text_tokenized},text,text" \
        --num_workers 1 \
        --split_with_space False 
fi
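Since finetune.sh sources utils/parse_options.sh, the stages can be selected from the command line, for example:

# tokenization + data preparation + training
bash finetune.sh --stage 0 --stop_stage 2
# run only the test stage on the fine-tuned model
bash finetune.sh --stage 3 --stop_stage 3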

4) Modify the ../../../funasr/datasets/preprocessor.py file to adapt the text and punctuation processing required for training.

Modify the PuncTrainTokenizerCommonPreprocessor class in preprocessor.py as follows (a brief usage sketch follows the class):

class PuncTrainTokenizerCommonPreprocessor(AbsPreprocessor):
    def __init__(
            self,
            train: bool,
            token_type: Union[str, List[str]] = [None],
            token_list: List[Union[Path, str, Iterable[str]]] = [None],
            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: List[str] = ["text"],
            vad_name: str = "vad_indexes",
    ):
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(train)
        self.token_type = token_type
        self.token_list = token_list
        self.bpemodel = bpemodel
        self.text_cleaner = text_cleaner
        self.g2p_type = g2p_type
        self.unk_symbol = unk_symbol
        self.space_symbol = space_symbol
        self.non_linguistic_symbols = non_linguistic_symbols
        self.delimiter = delimiter
        self.speech_name = speech_name
        self.text_name = text_name
        self.rir_scp = rir_scp
        self.rir_apply_prob = rir_apply_prob
        self.noise_scp = noise_scp
        self.noise_apply_prob = noise_apply_prob
        self.noise_db_range = noise_db_range
        self.speech_volume_normalize = speech_volume_normalize

        assert (
                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor
        self.vad_name = vad_name

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            #import pdb; pdb.set_trace()
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                if "vad:" in tokens[-1]:
                    vad = tokens[-1][4:]
                    tokens = tokens[:-1]
                    if len(vad) == 0:
                        vad = -1
                    else:
                        vad = int(vad)
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data

    def __call__(
            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:

        data = self._text_process(data)
        return data
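For reference, a rough instantiation sketch of the modified class, assuming the two data streams ("text" and "punc") and the vocabularies generated by stage 0 above; the exact arguments passed by punc_train.py may differ:

preprocessor = PuncTrainTokenizerCommonPreprocessor(
    train=True,
    token_type=["word", "word"],
    token_list=["./tokens.txt", "./puncs.txt"],
    bpemodel=[None, None],
    text_name=["text", "punc"],
    delimiter=" ",
)
# Each sample carries space-delimited tokens; __call__ maps them to int64 id arrays
sample = preprocessor("utt_0", {"text": "快 过 年 了", "punc": "_ _ _ ,"})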

2. Punctuation model ONNX export

Since the repository already implements ONNX export for the punctuation model, no extra export script is needed; the official command is enough: python -m funasr.export.export_model --model-name punc_model_dir --export-dir ./export --type torch --quantize false
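After export, a quick sanity check with onnxruntime confirms the graph loads and shows its input/output names; the path below is an assumption, point it at the ONNX file written under --export-dir:

import onnxruntime as ort

sess = ort.InferenceSession("./export/punc_model_dir/model.onnx")  # hypothetical export path
print([(i.name, i.shape) for i in sess.get_inputs()])
print([(o.name, o.shape) for o in sess.get_outputs()])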

LRY1994 commented 9 months ago

The paper CONTROLLABLE TIME-DELAY TRANSFORMER FOR REAL-TIME PUNCTUATION PREDICTION AND DISFLUENCY DETECTION proposes jointly modeling punctuation prediction and disfluency detection. Where in the code is the disfluency detection part?

184653090 commented 4 months ago

Which branch is this tutorial based on?

lancelee98 commented 3 months ago

What should the content of conf/train_punc.yaml be? And why do I get this error when following your steps? @ROAD2018 [screenshot]

otoTree commented 1 month ago

Error while finding module specification for 'funasr.export.export_model' (ModuleNotFoundError: No module named 'funasr.export')

wowfingerlicker commented 3 weeks ago

How is tokens.txt obtained?