XiaoMi / kaldi-onnx

Kaldi model converter to ONNX
Apache License 2.0

Problem with compiling offset feedback to fast-lstmp layer #46

Status: Open. ZuoyunZheng opened this issue 2 years ago.

ZuoyunZheng commented 2 years ago

Let's say I have an nnet3 model that looks something like this:

    input-node name=input dim=40
    component-node name=tdnn1.affine component=lda.tdnn1.affine input=Append(Offset(input, -1), input, Offset(input, 1))
    component-node name=tdnn1.relu component=tdnn1.relu input=tdnn1.affine
    component-node name=tdnn1.batchnorm component=tdnn1.batchnorm input=tdnn1.relu
    component-node name=tdnn2.affine component=tdnn2.affine input=Append(Offset(tdnn1.batchnorm, -1), tdnn1.batchnorm, Offset(tdnn1.batchnorm, 1))
    component-node name=tdnn2.relu component=tdnn2.relu input=tdnn2.affine
    component-node name=tdnn2.batchnorm component=tdnn2.batchnorm input=tdnn2.relu
    dim-range-node name=lstm1.c input-node=lstm1.lstm_nonlin dim-offset=0 dim=512
    dim-range-node name=lstm1.m input-node=lstm1.lstm_nonlin dim-offset=512 dim=512
    component-node name=lstm1.rp component=lstm1.W_rp input=lstm1.m
    dim-range-node name=lstm1.r input-node=lstm1.rp dim-offset=0 dim=128
    component-node name=output.affine component=output.affine input=lstm1.rp
    component-node name=output.log-softmax component=output.log-softmax input=output.affine
    output-node name=output input=Offset(output.log-softmax, 5) objective=linear
    component-node name=lstm1.W_all component=lstm1.W_all input=Append(tdnn2.batchnorm, IfDefined(Offset(lstm1.r, -3)))
    component-node name=lstm1.lstm_nonlin component=lstm1.lstm_nonlin input=Append(lstm1.W_all, IfDefined(Offset(lstm1.c, -3)))

When converting this model, graph compilation gets stuck in the reorder_nodes method in converter/graph.py. The cause is a dependency cycle through the offset feedback: for example, the IfDefined node lstm1.c-3 cannot be reordered, because lstm1.lstm_nonlin takes IfDefined(Offset(lstm1.c, -3)) as an input while lstm1.c is itself a dim-range of lstm1.lstm_nonlin. Since I don't really understand the logic of line 407 in converter/graph.py, I was wondering if anyone could help me understand this problem better.
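A minimal sketch of why the reordering never finishes, assuming a simple "repeat passes until every node is placed" strategy (an illustration only, not the converter's actual reorder_nodes). The dependency map is hand-written from the LSTM part of the config above; lstm1.c-3 follows the naming in the text, and lstm1.r-3 is a hypothetical name for the IfDefined(Offset(lstm1.r, -3)) node. Every node in the loop waits on another node in the loop, so the first pass places nothing; a plain while loop would spin forever at that point, whereas this version stops and reports the leftover nodes.

```python
from typing import Dict, List, Set, Tuple


def reorder(deps: Dict[str, List[str]],
            inputs: Set[str]) -> Tuple[List[str], List[str]]:
    """Place nodes pass by pass once all of their inputs are placed.

    Returns (order, stuck). Instead of looping forever when a full pass
    places nothing, it stops and returns the nodes that were left over.
    """
    order: List[str] = []
    ready: Set[str] = set(inputs)
    pending = list(deps)
    while pending:
        progressed = False
        for name in list(pending):  # iterate over a copy so removal is safe
            if set(deps[name]) <= ready:
                order.append(name)
                ready.add(name)
                pending.remove(name)
                progressed = True
        if not progressed:
            break  # no node became ready in a full pass: a dependency cycle
    return order, pending


# Hand-written dependencies for the LSTM part of the config (hypothetical
# names "lstm1.c-3" / "lstm1.r-3" for the IfDefined(Offset(..., -3)) nodes).
DEPS: Dict[str, List[str]] = {
    "lstm1.W_all": ["tdnn2.batchnorm", "lstm1.r-3"],
    "lstm1.lstm_nonlin": ["lstm1.W_all", "lstm1.c-3"],
    "lstm1.c": ["lstm1.lstm_nonlin"],  # dim-range of lstm1.lstm_nonlin
    "lstm1.m": ["lstm1.lstm_nonlin"],  # dim-range of lstm1.lstm_nonlin
    "lstm1.rp": ["lstm1.m"],
    "lstm1.r": ["lstm1.rp"],           # dim-range of lstm1.rp
    "lstm1.c-3": ["lstm1.c"],
    "lstm1.r-3": ["lstm1.r"],
}

if __name__ == "__main__":
    order, stuck = reorder(DEPS, {"tdnn2.batchnorm"})
    print("placed:", order)
    print("stuck:", sorted(stuck))
```

Running it places nothing and reports all eight loop nodes as stuck, which matches the hang described above.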

Thanks in advance.

ZuoyunZheng commented 2 years ago

My ad-hoc fix was:

1. Reorder the function calls in prepare_graph():
    self.get_consts()
    # check inputs and outputs
    self.fetch_inputs_outputs()
    # disabled: at this point the IfDefined feedback loop is still in the
    # graph, so reorder_nodes() gets stuck
    #self.reorder_nodes()
    self.update_nodes_by_name()
    self.add_cache_nodes()
    self.fetch_inputs_outputs()
    # add PadContext for 'input' if the model has left or right context
    # fuse statistics extraction and pooling, lstm cell
    self.fuse_nodes()
    self.add_cache_nodes()
    self.fetch_inputs_outputs()
    self.fetch_model_inputs_outputs()
    self.reorder_nodes(False)
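2. Changes in reorder_nodes():
    ...
    while len(nodes_need_check) > 0:
        for node in list(nodes_need_check):
            ...
            # also accept an IfDefined node whose last input was renamed to
            # '<name>.Cache.IfDefined' by add_cache_nodes(), without waiting
            # for that input; this is what breaks the feedback cycle
            if set(depend_inputs) <= set(checked_names) \
                    or (node.type == KaldiOpType.IfDefined.name
                            and ifdefine and 'IfDefined' in node.inputs[-1]):
        ...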
3. Changes in add_cache_nodes():
    cache_nodes = list()
    for node in self._nodes:
        input_name = node.inputs[-1]
        # only IfDefined nodes that have not been redirected yet
        if node.type == KaldiOpType.IfDefined.name and \
                not input_name.endswith('.IfDefined'):
            _LOG.info(f"Appending cache nodes {input_name}")
            # use the producing node's name if it is a known node,
            # otherwise the raw input name
            if input_name in self._nodes_by_name:
                source_name = self._nodes_by_name[input_name].name
            else:
                source_name = input_name
            # add an Identity '<source>.Cache' node and point the IfDefined
            # node at '<source>.Cache.IfDefined' instead of the source itself
            cache_node_name = source_name + '.Cache'
            cache_node = make_node(cache_node_name,
                                   KaldiOpType.Identity.name,
                                   [source_name],
                                   [cache_node_name])
            cache_nodes.append(cache_node)
            node.inputs = node.inputs[:-1] + [cache_node_name + '.IfDefined']
    if len(cache_nodes) > 0:
        self._nodes.extend(cache_nodes)
        self.fetch_inputs_outputs()
        self.reorder_nodes()
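The idea behind the three changes above can be sketched on a hand-written dependency map of the LSTM loop (a simplified model, not the converter's code): each IfDefined node's last input is redirected to '<source>.Cache.IfDefined', a '<source>.Cache' node that copies the source is added, and the readiness test accepts an IfDefined node whose last input contains 'IfDefined' without waiting for it. The helper names, the hypothetical node names lstm1.r-3 and lstm1.r.Cache, and the omission of the patch's ifdefine flag are assumptions of this sketch; how the cached value is filled at runtime is not modelled.

```python
from typing import Dict, List, Set, Tuple


def add_cache_nodes(deps: Dict[str, List[str]],
                    ifdefined: Set[str]) -> Dict[str, List[str]]:
    """Redirect each IfDefined node's last input through a '<src>.Cache' node."""
    deps = {name: list(inputs) for name, inputs in deps.items()}  # work on a copy
    for name in sorted(ifdefined):
        src = deps[name][-1]
        if src.endswith(".IfDefined"):
            continue  # already redirected
        cache = src + ".Cache"
        deps.setdefault(cache, [src])  # the cache node simply copies its source
        deps[name][-1] = cache + ".IfDefined"
    return deps


def reorder(deps: Dict[str, List[str]], inputs: Set[str],
            ifdefined: Set[str]) -> Tuple[List[str], List[str]]:
    """Pass-by-pass reorder; a redirected IfDefined node counts as ready at once."""
    order: List[str] = []
    ready: Set[str] = set(inputs)
    pending = list(deps)
    while pending:
        progressed = False
        for name in list(pending):  # iterate over a copy so removal is safe
            if set(deps[name]) <= ready or (
                    name in ifdefined and "IfDefined" in deps[name][-1]):
                order.append(name)
                ready.add(name)
                pending.remove(name)
                progressed = True
        if not progressed:
            break  # the remaining nodes still form a cycle
    return order, pending


# Same hand-written LSTM dependencies as in the question (hypothetical names).
DEPS: Dict[str, List[str]] = {
    "lstm1.W_all": ["tdnn2.batchnorm", "lstm1.r-3"],
    "lstm1.lstm_nonlin": ["lstm1.W_all", "lstm1.c-3"],
    "lstm1.c": ["lstm1.lstm_nonlin"],
    "lstm1.m": ["lstm1.lstm_nonlin"],
    "lstm1.rp": ["lstm1.m"],
    "lstm1.r": ["lstm1.rp"],
    "lstm1.c-3": ["lstm1.c"],
    "lstm1.r-3": ["lstm1.r"],
}
IFDEFINED = {"lstm1.c-3", "lstm1.r-3"}

if __name__ == "__main__":
    order, stuck = reorder(add_cache_nodes(DEPS, IFDEFINED),
                           {"tdnn2.batchnorm"}, IFDEFINED)
    print("order:", order)
    print("stuck:", stuck)
```

With the redirect, the two IfDefined nodes are placed in the first pass and the rest of the loop (plus the two cache nodes) in the second, so nothing is left stuck; calling reorder on the original DEPS still places nothing. One difference from the patch: setdefault creates a single cache node per source even if several IfDefined nodes share it, whereas the patch appends one cache node per IfDefined node.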