aalto-speech / AaltoASR

Aalto Automatic Speech Recognition tools

openfst_decoder #20

Open vsiivola opened 8 years ago

vsiivola commented 8 years ago

OpenFst decoder code. Not very well documented, but you basically build the search network pretty much the same way you would for Kaldi (except using Aalto acoustic models). You'll probably want to make compiling this code conditional if you take this pull request. Surprisingly, it needs the OpenFst headers and library.

ammansik commented 8 years ago

Do you have any scripts available for converting an AaltoASR lexicon, together with the phoneme set and acoustic model, into an OpenFst network (HCLG)?

vsiivola commented 8 years ago

There's a full C++ implementation for grammar networks in production, but it uses some code that I cannot share. It would take a while to separate the shareable parts from the non-shareable ones.

Here are some scattered notes that I haven't tested or used in a while...

These first notes should take care of the triphones and the triphone-to-monophone mappings (C and H). The grammar generation for the BNF format is unfortunately code that I cannot share, but arpa2fst from Kaldi should do it for ARPA models (never tested). I also have code for generating G from the HTK lattice format, if that is of interest. Creating L should be straightforward; there is an example at the end of this message (it also assumes the L input symbols are ISO-8859-1 and the output symbols UTF-8, and that they can be found in or added to the OpenFst symbol tables m_acu_osyms/m_latin_osyms). I would need to find the time to document this properly, but right now that time is hard to find.
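
To make the L part concrete, here is a minimal sketch (mine, not from the production code) that writes a tiny monophone lexicon in OpenFst text format plus the two symbol tables, following the same structure as the create_monophone_lexfst() example further down: state 0 is the start, state 1 the only final state, the word comes out on the first arc and the remaining arcs output <eps>. The example entries, file names and the one-phoneme-per-character convention are assumptions for illustration only.

# Hypothetical word -> pronunciation entries; pronunciations are treated as one
# phoneme per character, like in the C++ example below.
lexicon = {"yksi": "yksi", "kaksi": "kaksi"}

isyms = {"<eps>": 0}
osyms = {"<eps>": 0}

def sym_id(table, sym):
    # Assign ids on first use, similar to SymbolTable::AddSymbol().
    return table.setdefault(sym, len(table))

with open("L.txt", "w", encoding="iso-8859-15") as out:
    next_state = 2                       # 0 = start state, 1 = final state
    for word, phones in lexicon.items():
        prev = 0
        for i, phone in enumerate(phones):
            osym = word if i == 0 else "<eps>"
            dst = 1 if i == len(phones) - 1 else next_state
            if dst != 1:
                next_state += 1
            sym_id(isyms, phone)
            sym_id(osyms, osym)
            out.write("%d %d %s %s\n" % (prev, dst, phone, osym))
            prev = dst
    out.write("1\n")                     # mark state 1 as final

for fname, table in (("L.isyms", isyms), ("L.osyms", osyms)):
    with open(fname, "w", encoding="iso-8859-15") as out:
        for sym, idx in sorted(table.items(), key=lambda kv: kv[1]):
            out.write("%s %d\n" % (sym, idx))

# Compile with e.g.:
# /opt/openfst/bin/fstcompile --isymbols=L.isyms --osymbols=L.osyms L.txt L-bin.fst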

####################################

# Pregenerate the context model
(cd ../../aku/scripts; ./hmms2trinet.pl /home/stt/Models/asr_models/am/phone/lsphone_speechdat_elisa_spuh-small-ml_20.ph .) | /opt/mitfsm/bin/fst_optimize -A - C.fst
python3 /home/vsiivola/Code/acu_train/acutrainer/htk2fst.py --to_openfst C.syms --mitfst_source > C-openfst.fst < C.fst
/opt/openfst/bin/fstcompile --isymbols=C.syms --osymbols=C.syms.out C-openfst.fst C-bin.fst

# Pregenerate the triphone model
../decoder/src/hmm2fsm /home/stt/Models/asr_models/am/phone/lsphone_speechdat_elisa_spuh-small-ml_20.ph tritmp.fst && /opt/mitfsm/bin/fst_closure tritmp.fst tri.fst
python3 /home/vsiivola/Code/acu_train/acutrainer/htk2fst.py --to_openfst tri.syms --mitfst_source > tri-openfst.fst < tri.fst
/opt/openfst/bin/fstcompile --isymbols=tri.syms --osymbols=C.syms tri-openfst.fst tri-bin.fst

# The following could be used directly instead of using the separate pieces as created above
# compose the context and triphone models, convert to openfst
# FIXME: Cannot remember if there was a problem with this direct approach?
/opt/mitfsm/bin/fst_compose tri.fst C.fst - | /opt/mitfsm/bin/fst_determinize - - | /opt/mitfsm/bin/fst_minimize - - | /opt/mitfsm/bin/fst_optimize -A - - | python3 /home/vsiivola/Code/acu_train/acutrainer/htk2fst.py --to_openfst triC.syms --mitfst_source > triC.fst

# Compile to usable openfst bin representation
/opt/openfst/bin/fstcompile --isymbols=triC.syms --osymbols=triC.syms.out --keep_isymbols --keep_osymbols triC.fst triC-bin.fst

####################################
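
As an untested alternative to the mitfsm fst_compose/fst_determinize/fst_minimize pipeline above, the same composition can be sketched with OpenFst's optional Python wrapper (pywrapfst), assuming it has been built for your OpenFst install and that tri-bin.fst and C-bin.fst were compiled as in the commands above. Take this only as a rough sketch:

import pywrapfst as fst

# Load the binary FSTs produced by fstcompile above.
tri = fst.Fst.read("tri-bin.fst")
C = fst.Fst.read("C-bin.fst")

# Composition wants sorted arcs: sort tri on output labels, C on input labels.
tri.arcsort(sort_type="olabel")
C.arcsort(sort_type="ilabel")

triC = fst.compose(tri, C)

# Mirror the fst_determinize / fst_minimize steps; this assumes the composed
# machine is determinizable (disambiguation symbols may be needed).
triC = fst.determinize(triC)
triC.minimize()

triC.write("triC-bin.fst")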

The lexicon can be generated along the lines of:

std::shared_ptr<fst::StdVectorFst> LatticeOperations::create_monophone_lexfst(
    const std::string &lexstring, std::set<std::string> *dedup_set) {
    std::shared_ptr<fst::StdVectorFst> monolex(new fst::StdVectorFst);
    monolex->SetInputSymbols(m_acu_osyms_ptr.get());
    monolex->SetOutputSymbols(m_latin_osyms.get());

    monolex->AddState();
    monolex->AddState();
    monolex->SetStart(0);
    monolex->SetFinal(1, 0.0);

    // Create lex from string if supplied, else
    // just use the output symbols
    std::vector<std::pair<std::string, std::string> > tokens(
            lexstring.size()?lexstring2tokenvector(lexstring):
            symboltable2tokenvector(m_latin_osyms.get()));

    for (auto kv: tokens) {
            const std::string &isym(kv.first);
            const std::string &osym(kv.second);

            if ( osym == "<eps>" ) {
                    continue;
            }

            if (dedup_set) {
                    if (dedup_set->find(isym) != dedup_set->end()) {
                            std::cerr << "Warning, duplicate lex entry ("
                                      << osym << " "<< isym
                                      << "), discarding." << std::endl;
                            continue;
                    }
                    dedup_set->insert(isym);
            }

            //std::cerr << "Got isym '" << isym << "' osym '" << osym << "'" << std::endl;

            // Store the mappings, since we need to expand the output symbols
            // if we need to figure out which path was taken
            // e.g. osym "1" expands to ["1#yks", "1#yksi"]
            std::string pvarstring("!" + isym);
            if (isym == "_" || isym == "__" || isym == "<eps>") {
                    pvarstring = "";
            }
            m_token2pvariants[osym].push_back(pvarstring);

            // Build the network
            std::string posym(osym + pvarstring);
            int64 osymidx = m_latin_osyms->Find(posym);
            if (osymidx == -1) {
                    //std::cout << "Adding latin osym " << posym <<
                    //        " at create_monophone_lexfst()" << std::endl;
                    osymidx = m_latin_osyms->AddSymbol(posym);
            }

            if (isym.size() == 1 || osym == "__"
                    || osym == "<s>" || osym == "</s>") {
                    int64 symidx = m_acu_osyms_ptr->Find(osym);
                    if (symidx == -1) {
                            if ( osym == "<s>" || osym =="</s>") {
                                    symidx = m_acu_osyms_ptr->Find("<eps>");
                            } else {
                                    std::cout << "Adding isym " << isym <<
                                            " at create monophone_lexfst()" << std::endl;
                                    //symidx = m_acu_osyms_ptr->AddSymbol(isym);
                                    throw LatticeException("Unexpected phoneme symbol " + isym);
                            }
                    }

                    monolex->AddArc(
                            0, fst::StdArc(symidx, osymidx, 0.0, 1));
                    continue;
            }

            int64 prev_state = 0;
            for (int i=0; i<isym.size(); i++) {
                    char c[2];
                    c[1] = '\0';
                    c[0] = isym[i];
                    std::string subsym(c[0]==' '?"__":c);
                    int64 symidx = m_acu_osyms_ptr->Find(subsym);
                    if (symidx == -1) {
                            symidx = m_acu_osyms_ptr->AddSymbol(subsym);
                    }
                    int64 new_state;
                    if (i<isym.size()-1) {
                            new_state = monolex->AddState();
                    } else {
                            new_state = 1;
                    }
                    // Only the first arc of the path carries the word label;
                    // the remaining arcs output epsilon.
                    int64 arc_osym = prev_state ? m_eps_idx : osymidx;

                    monolex->AddArc(
                            prev_state, fst::StdArc(symidx, arc_osym, 0.0, new_state));
                    prev_state = new_state;
            }

    }
    return monolex;
}

std::vector<std::pair<std::string, std::string> >
LatticeOperations::lexstring2tokenvector(const std::string &lexstring) {
    // Split string to vector of pairs<insymbol, outsymbol>
    std::vector<std::pair<std::string, std::string> > lexvec;
    std::istringstream iss(lexstring);
    std::string line;

    // For ICU utf-8 -> iso-8859-15 conversion
    UErrorCode status = U_ZERO_ERROR;
    UConverter *conv = ucnv_open("utf-8", &status);
    if ( status != U_ZERO_ERROR ) {
            throw LatticeException("ICU ucnv_open failed with error code " + std::to_string(status));
    }

    while (iss.good()) {
            std::getline(iss, line);
            if (line.size()==0) {
                    continue;
            }
            boost::trim(line);
            // Reserve room for the converted bytes plus the NUL terminator;
            // UTF-8 -> ISO-8859-15 output is never longer than the input.
            size_t out_capacity = line.size() + 1;
            char *pOutputBuffer = (char *) malloc(out_capacity);

            int32_t out_len = ucnv_convert("iso-8859-15", "utf-8",
                                           pOutputBuffer, out_capacity,
                                           line.c_str(), line.size(), &status);
            if ( status != U_ZERO_ERROR ) {
                    free(pOutputBuffer);
                    throw LatticeException("ICU ucnv_convert failed with error code " + std::to_string(status));
            }
            line = std::string(pOutputBuffer, out_len);
            free(pOutputBuffer);
            auto first_separator = line.find("\t");
            if (first_separator == std::string::npos) {
                    first_separator = line.find(" ");
                    if (first_separator==std::string::npos) {
                            throw LatticeException("Lex parse error: "+line);
                    }
            }
            std::string firsts(line.substr(first_separator+1));
            std::string seconds(line.substr(0, first_separator));
            //std::cerr << "Create pair " << firsts << " && " << seconds << std::endl;
            lexvec.push_back(std::make_pair(firsts, seconds));
    }
    ucnv_close(conv);

    if (!m_metasyms_added) {
            // FIXME : A separate fst to union for these !
            //std::cerr << "Add metasyms" << std::endl;
            lexvec.push_back(std::make_pair("__", "__"));
            m_metasyms_added = true;
    }
    return lexvec;
}
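
For reference, lexstring2tokenvector() expects one entry per line in the form word<TAB>pronunciation (a single space also works as the separator) and returns (pronunciation, word) pairs. A toy rendering of the same parse in Python, with made-up entries modelled on the "1 -> yks / yksi" comment in the code above:

lexstring = "1\tyks\n1\tyksi\nkaksi\tkaksi\n"   # hypothetical word<TAB>pronunciation entries

pairs = []
for line in lexstring.splitlines():
    line = line.strip()
    if not line:
        continue
    sep = line.find("\t")
    if sep == -1:
        sep = line.find(" ")              # fall back to space, as in the C++ code
    # (input symbols, output symbol) = (pronunciation, word)
    pairs.append((line[sep + 1:], line[:sep]))

print(pairs)   # [('yks', '1'), ('yksi', '1'), ('kaksi', 'kaksi')]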
ammansik commented 8 years ago

Thank you very much for the detailed help. I'll try to get started with this. Btw, is the script htk2fst.py available somewhere? I guess it's not the same as the Perl script htk2fst.pl found in the aku scripts dir?

vsiivola commented 8 years ago

htk2fst.py:

#!/usr/bin/env python3
import codecs
import io
import math
import operator
import re
import sys

def htk2mitfst(in_iter, lm_scale=1.0, ac_scale=None,
               store_lm_prob_out_labels=False):
    yield "#FSTBasic MaxPlus"
    for l in in_iter:
        m = re.match(r"start=(\d+)", l)
        if m:
            yield "I %s" % m.group(1)

        m = re.search(r"\send=(\d+)", l)
        if m:
            final = m.group(1)
            yield "F %s" % final
            break

    log10 = math.log(10)
    for l in in_iter:
        m = re.match(\
          r"J=\d+\s+S=(\d+)\s+E=(\d+)\s+W=(\S+)\s+a=(\S+)\s+l=(\S+)\s*$", l)
        if m:
            w = re.sub("!NULL", ",", m.group(3))
            logprob = 0
            unscaled_lm_logprob = float(m.group(5)) * log10
            if lm_scale:
                logprob += lm_scale * unscaled_lm_logprob
            if ac_scale:
                logprob += ac_scale * float(m.group(4)) * log10
            if store_lm_prob_out_labels:
                yield "T %s %s %s %f %g" % \
                    (m.group(1), m.group(2), w, unscaled_lm_logprob, logprob)
            else:
                yield "T %s %s %s %s %g" % \
                    (m.group(1), m.group(2), w, w, logprob)

    yield "%s" % final

def mitfst2openfst(in_iter, store_lm_prob_out_labels=False, eps_symbol="."):
    symbol_dict = {"<eps>": 0}
    if store_lm_prob_out_labels:
        out_symbol_dict = {"<eps>": 0}
        owidx = 1
    widx = 1
    retstring = ""
    final_states = []
    for l in in_iter:
        if l.startswith("F "):
            final_states.append(l.split()[1])
        # Recode labels
        if l.startswith("T "):
            fields = l.split()
            t = fields[0]
            from_state = fields[1]
            to_state = fields[2]
            label1 = fields[3] if len(fields) > 3 else eps_symbol
            label2 = fields[4] if len(fields) > 4 else eps_symbol
            if label1 == ",":
                label1 = eps_symbol
            if label2 == ",":
                label2 = eps_symbol
            prob = fields[5] if len(fields) == 6 else 0.0

            prob = max(0.0, -float(prob))
            if not label1 in symbol_dict:
                symbol_dict[label1] = widx
                widx += 1
            if store_lm_prob_out_labels:
                if label2 not in out_symbol_dict:
                    out_symbol_dict[label2] = owidx
                    owidx += 1
                retstring += "%s %s %s %s %g\n" % \
                    (from_state, to_state, label1, label2, prob)
            else:
                assert label1 == label2
                retstring += "%s %s %s %s %g\n" % \
                    (from_state, to_state, label2, label2, prob)
            continue
    for final in final_states:
        retstring += "%s\n" % final
    if store_lm_prob_out_labels:
        return retstring, symbol_dict, out_symbol_dict
    return retstring, symbol_dict

def openfst2mitfst(in_iter, rescale=None):
    yield "#FSTBasic MaxPlus"

    final = None
    start = None
    resarray = []
    for l in in_iter:
        valtuple = l.split()
        if len(valtuple) == 1:
            assert not final
            final = valtuple[0]
            continue
        if len(valtuple) == 4:
            from_state, to_state, sym, sym2 = valtuple
            prob = 0.0
        elif len(valtuple) == 5:
            from_state, to_state, sym, sym2, prob = valtuple
            if rescale:
                prob = float(prob)/rescale
            else:
                prob = float(sym2)
        if not start:
            yield "I %s" % from_state
            start = True
        resarray.append((from_state, to_state, sym, prob))
    yield "F %s" % final
    for from_state, to_state, sym, prob in resarray:
        yield "T %s %s %s %s %g" %(from_state, to_state, sym, sym, prob)
    yield "%s" % final

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description="Convert between fst representations",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(\
        '--to_openfst', metavar="FILENAME",
        default=False, help="convert to openfst, write symbol map to FILENAME")
    parser.add_argument(\
        '--from_openfst', action="store_true",
        default=False, help="convert from openfst to mitfst")
    parser.add_argument(
        '-l', '--lm_scale', metavar='FLOAT', type=float,
        help='scale lm logprobs by FLOAT, only valid for '
        'conversion from htk, default 1.0', default=1.0)
    parser.add_argument(
        '-a', '--ac_scale', metavar='FLOAT', type=float,
        help='scale acu logprobs by FLOAT, only valid '
        'for conversion from htk, default 0.0', default=0.0)
    parser.add_argument(
        '-r', '--rescale', metavar='FLOAT', type=float,
        help='rescale openfst probs, if not set get the probs '
        'from fst out labels directly', default=0.0)
    parser.add_argument(
        '-m', '--mitfst_source', action="store_true",
        help='Input is mitfsm instead of htk',
        default=False)

    args = parser.parse_args()

    sys.stdout = codecs.getwriter("iso-8859-15")(sys.stdout.detach())

    lm_scale = args.lm_scale if args.lm_scale != 0.0 else None
    ac_scale = args.ac_scale if args.ac_scale != 0.0 else None

    if args.from_openfst:
        print("\n".join(openfst2mitfst(\
            io.TextIOWrapper(sys.stdin.buffer, encoding='iso-8859-15'),
            args.rescale if args.rescale != 0.0 else None)))
        sys.exit(0)

    if not args.to_openfst:
        print("\n".join(htk2mitfst(\
            io.TextIOWrapper(sys.stdin.buffer, encoding='iso-8859-15'),
            lm_scale, ac_scale)))
        sys.exit(0)

    inwrap = io.TextIOWrapper(sys.stdin.buffer, encoding='iso-8859-15')
    if not args.mitfst_source:
        inwrap = htk2mitfst(inwrap, lm_scale, ac_scale, True)
    retstring, symbol_dict, out_symbol_dict = mitfst2openfst(
        inwrap, True, "<eps>" if args.mitfst_source else ",")

    sym_fh = open(args.to_openfst, "w", encoding="iso-8859-1")
    for k, v in sorted(symbol_dict.items(), key=operator.itemgetter(1)):
        sym_fh.write("%s %d\n" % (k, v))
    sym_fh.close()

    sym_fh = open(args.to_openfst+".out", "w", encoding="iso-8859-1")
    for k, v in sorted(out_symbol_dict.items(), key=operator.itemgetter(1)):
        sym_fh.write("%s %d\n" % (k, v))
    sym_fh.close()

    print(retstring)

###########################################
# NOTES:
# python3 ../acutrainer/htk2fst.py --to_openfst syms.openfst -l 30 -a 1 < Pasi_Ryhanen_näkyvyys.lattice_tool > Pasi_Ryhanen_näkyvyys.lattice_tool-openfst
# /opt/openfst/bin/fstcompile --isymbols=syms.openfst --osymbols=syms.openfst.out Pasi_Ryhanen_näkyvyys.lattice_tool-openfst Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-bin
# /opt/openfst/bin/fstprune --weight=1.0 Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-bin Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-bin-pruned
# /opt/openfst/bin/fstprint --isymbols=syms.openfst --osymbols=syms.openfst.out Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-bin-pruned Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-pruned
# python3 ../acutrainer/htk2fst.py --from_openfst < Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-pruned > Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-pruned-mitfst
#
# /opt/openfst/bin/fstdraw --isymbols=syms.openfst --osymbols=syms.openfst Pasi_Ryhanen_näkyvyys.lattice_tool-openfst-bin binary.dot
# dot -Tps -Gcharset=latin1 binary.dot  > binary.ps
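
To complement the command-line notes above, here is a small in-process sketch of driving the two conversion functions directly on a toy SLF-style lattice. The lattice content and scores are made up, and it assumes the script above is saved as htk2fst.py somewhere on the Python path:

from htk2fst import htk2mitfst, mitfst2openfst

# Toy HTK/SLF-style lattice: a header line with the start/end nodes, then one
# J= line per arc (W = word, a = acoustic score, l = LM score; the script
# treats both as base-10 logprobs).
slf = iter([
    "VERSION=1.0",
    "start=0 end=2",
    "J=0 S=0 E=1 W=moi a=-120.0 l=-0.5",
    "J=1 S=1 E=2 W=maailma a=-150.0 l=-0.7",
])

mit = htk2mitfst(slf, lm_scale=30.0, ac_scale=1.0)     # generator of mitfst text lines
text, syms = mitfst2openfst(mit, eps_symbol="<eps>")   # openfst text body + symbol map

print(text)   # arc lines "from to label label weight" followed by the final state
print(syms)   # e.g. {'<eps>': 0, 'moi': 1, 'maailma': 2}
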
vsiivola commented 8 years ago

I noticed one mistake in this pull request: when running without the open phone loop for confidence, it pruned tokens in non-final states in the wrong place. Fixed code committed.