jzlianglu / pykaldi2

Yet another speech toolkit based on Kaldi and PyTorch
MIT License

minibatches for LFMMI #22

Open yotam319 opened 4 years ago

yotam319 commented 4 years ago

Hi, I have a few suggestions on using LF-MMI. I noticed you are looping over the batch, creating the supervision and calculating the criterion for each item separately. You can use the MergeSupervision function I added to pykaldi to create one supervision for the whole batch and run the criterion only once.

Here is my collate function for the data loader:

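from collections import abc as container_abcs

import kaldi
import kaldi.chain
import torch

def supervision_collate(batch):
    """
    A collate function for using supervision with a DataLoader.
    """
    elem = batch[0]
    # Strings are sequences too; excluding them avoids endless recursion.
    if isinstance(elem, container_abcs.Sequence) and not isinstance(elem, (str, bytes)):
        # Each sample is a sequence of fields: regroup by field and recurse.
        transposed = zip(*batch)
        return [supervision_collate(samples) for samples in transposed]
    elif isinstance(elem, kaldi.chain._chain_supervision.Supervision):
        if len(batch) == 1:
            return batch[0]
        # Name as written in the original post; check the spelling your pykaldi build exposes.
        return kaldi.chain.merge_supervison(batch)
    elif elem is None:
        return batch
    return torch.utils.data.dataloader.default_collate(batch)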
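
To see how the recursion in this collate function regroups a batch, here is a minimal sketch that runs without pykaldi or PyTorch. `FakeSupervision` and `fake_merge` are hypothetical stand-ins for `kaldi.chain.Supervision` and the merge call, and a plain list stands in for `default_collate`; only the control flow is meant to mirror the function above.

```python
from collections.abc import Sequence


class FakeSupervision:
    """Hypothetical stand-in for kaldi.chain.Supervision (not the real class)."""

    def __init__(self, num_sequences=1, utt_ids=()):
        self.num_sequences = num_sequences
        self.utt_ids = list(utt_ids)


def fake_merge(supervisions):
    """Hypothetical stand-in for the merge call: one object covering all inputs."""
    merged = FakeSupervision(num_sequences=0)
    for sup in supervisions:
        merged.num_sequences += sup.num_sequences
        merged.utt_ids.extend(sup.utt_ids)
    return merged


def sketch_collate(batch):
    """Same branching as supervision_collate, with stand-ins for kaldi and torch."""
    elem = batch[0]
    if isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
        # Each sample is a tuple of fields: regroup by field, then recurse.
        return [sketch_collate(list(samples)) for samples in zip(*batch)]
    elif isinstance(elem, FakeSupervision):
        # One merged supervision per batch, so the criterion can run once.
        return batch[0] if len(batch) == 1 else fake_merge(batch)
    elif elem is None:
        return batch
    # Stand-in for default_collate: keep the values in a list.
    return list(batch)


# Each sample is (utterance id, frame count, supervision); all values are placeholders.
samples = [
    ("utt1", 150, FakeSupervision(utt_ids=["utt1"])),
    ("utt2", 150, FakeSupervision(utt_ids=["utt2"])),
    ("utt3", 150, FakeSupervision(utt_ids=["utt3"])),
]
ids, frames, merged = sketch_collate(samples)
print(ids, frames, merged.num_sequences)
```

With the real function, the same shape comes back from the data loader when it is passed as `collate_fn`: one entry per field, with the supervision field already merged.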

To add MergeSupervision I opened a pull request against pykaldi (https://github.com/pykaldi/pykaldi/pull/182), but you can use my fork, which already has the change (https://github.com/yotam319/pykaldi).

Also, using phone_ali gives a small supervision. You should consider using lattices and phone_lattice_to_proto_supervision instead of alignment_to_proto_supervision.

Finally, you can save your supervisions as bytes and read them back later. Here are the functions I use for this:

import kaldi
import kaldi.base.io
import kaldi.lat.align
from kaldi import chain
from kaldi.fstext import StdVectorFst

def supervision_to_bytes(supervision):
    # Serialize the supervision in Kaldi binary format (second argument: binary=True).
    out_s = kaldi.base.io.stringstream()
    supervision.write(out_s, True)
    return out_s.to_bytes()

def supervision_from_supervision_bytes(supervision_bytes):
    # Inverse of supervision_to_bytes: parse the binary form into a new Supervision.
    in_s = kaldi.base.io.stringstream.from_str(supervision_bytes)
    supervision = kaldi.chain.Supervision()
    supervision.read(in_s, True)
    return supervision

def split_supervision(supervision, start, duration):
    # Cut out the frames [start, start + duration) and remove epsilons from the result.
    sup_cut = kaldi.chain.SupervisionSplitter(supervision).get_frame_range(start, duration)
    sup_cut.fst = StdVectorFst(sup_cut.fst).rmepsilon()
    return sup_cut

def ali_phone_to_supervision_bytes(phones_durs,
                             opt, ctx_dep, trans_model):
    r"""
    input:
    phones_durs: list of (phone, duration) tuples
    opt: kaldi.chain.SupervisionOptions object
    ctx_dep: from kaldi.alignment.Aligner.read_tree("exp\chain\<ref_model>\tree")
    trans_model: from kaldi.alignment.Aligner.read_model("exp\chain\<ref_model>\0.trans_mdl")

    returns: byte representation of supervision
    """
    p_supervision = chain.alignment_to_proto_supervision_with_phones_durs(opt, phones_durs)
    supervision = chain.proto_supervision_to_supervision(ctx_dep, trans_model, p_supervision, opt.convert_to_pdfs)
    return supervision_to_bytes(supervision)

def lat_to_supervision_bytes(lat, phone_lat_mdl, phone_lat_opts,
                             supervision_opts, ctx_dep, trans_model):
    r"""
    input:
    lat: lattice
    phone_lat_mdl: final.mdl from the lat folder
    phone_lat_opts: PhoneAlignLatticeOptions object
    supervision_opts: kaldi.chain.SupervisionOptions object
    ctx_dep: from kaldi.alignment.Aligner.read_tree("exp\chain\<ref_model>\tree")
    trans_model: from kaldi.alignment.Aligner.read_model("exp\chain\<ref_model>\0.trans_mdl")

    returns: byte representation of supervision
    """
    (suc, phone_lat) = kaldi.lat.align.phone_align_lattice(lat, phone_lat_mdl, phone_lat_opts)
    assert suc
    # Sort the phone-aligned lattice topologically before building the proto-supervision.
    phone_lat.topsort()
    p_supervision = chain.phone_lattice_to_proto_supervision(supervision_opts, phone_lat)
    supervision = chain.proto_supervision_to_supervision(ctx_dep, trans_model, p_supervision, supervision_opts.convert_to_pdfs)
    return supervision_to_bytes(supervision)
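
Once the supervisions are plain bytes, they can go into any store that holds binary blobs. As one possible sketch (not something the toolkit or pykaldi provides), the snippet below keeps them in a SQLite table keyed by utterance id, using only the standard library. The table layout and the helper names `open_store`, `put_blob` and `get_blob` are assumptions made for illustration, and the stored value is a placeholder rather than a real serialized supervision.

```python
import sqlite3


def open_store(path):
    """Open (or create) a SQLite table mapping utterance id -> bytes."""
    conn = sqlite3.connect(path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS supervision (utt_id TEXT PRIMARY KEY, data BLOB)"
    )
    return conn


def put_blob(conn, utt_id, blob):
    """Insert or overwrite the bytes stored for one utterance."""
    conn.execute(
        "INSERT OR REPLACE INTO supervision (utt_id, data) VALUES (?, ?)",
        (utt_id, blob),
    )
    conn.commit()


def get_blob(conn, utt_id):
    """Return the stored bytes, or None if the utterance id is unknown."""
    row = conn.execute(
        "SELECT data FROM supervision WHERE utt_id = ?", (utt_id,)
    ).fetchone()
    return None if row is None else bytes(row[0])


# Placeholder bytes; in practice this would be the output of supervision_to_bytes.
conn = open_store(":memory:")
put_blob(conn, "utt1", b"\x00placeholder\xff")
restored = get_blob(conn, "utt1")
print(restored == b"\x00placeholder\xff")
```

In a dataset's `__getitem__`, the bytes returned by `get_blob` would then be passed to `supervision_from_supervision_bytes` to rebuild the supervision object.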

hope this helps :)

jzlianglu commented 4 years ago

Hi @yotam319, thanks a lot for your advice and code sample. Yes, what you said definitely makes sense. Previously, I only implemented a vanilla version of LF-MMI in the toolkit, and planned to revisit it later to improve the efficiency. I noticed the code change in the pykaldi lib, but have not been able to find the time to work on it. Our internal tools are not built on Kaldi, so I have very limited time to work on this toolkit. I will try to integrate your data loader into the code base soon. Thanks again!