jiazhihao / TASO

The Tensor Algebra SuperOptimizer for Deep Learning
Apache License 2.0

Error when optimizing BERT-SQuAD #54

Open StepyHan opened 4 years ago

Hi, dear TASO authors. I tried to optimize BERT using the ONNX model from https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad and got the following error:

python3: /home/workspace/TASO/taso/src/core/reshape.cc:39: taso::Tensor* taso::Graph::reshape(taso::TensorHandle, const std::vector<int>&): Assertion `input_size == 1' failed.

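The error seems to be caused by _reshape(op, graph, tensors, initializer). The target shape (the second input of the ONNX Reshape op, op.input[1]) is only looked up in the initializer list. When it is instead produced by other ops, the loop never finds it, shape stays an empty list, and graph.reshape fails. An empty shape describes a scalar, which would explain why the C++ side asserts that the input has exactly one element.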
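Here is a quick arithmetic check of that reasoning. It is a sketch only: the input shape below is a made-up example, not read from the model, and num_elements is a hypothetical helper. A tensor's element count is the product of its dimensions. The product over an empty shape is 1, so no multi-element activation can be reshaped to it.

```python
import math


def num_elements(shape):
    # Element count = product of the dimensions.
    # math.prod of an empty sequence is 1, so shape () / [] describes a scalar.
    return math.prod(shape)


input_dims = (1, 256, 768)  # hypothetical BERT activation shape, for illustration only
target_shape = []           # what _reshape passes when op.input[1] is not an initializer
print(num_elements(input_dims), num_elements(target_shape))  # 196608 1
```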

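from onnx import numpy_helper

def _reshape(op, graph, tensors, initializer):
    inputs = _get_inputs(op, graph, tensors, initializer)
    assert len(inputs) == 2
    shape = list()
    # Look for the target shape (op.input[1]) among the graph initializers only.
    for data in initializer:
        if data.name == op.input[1]:
            shape = list()
            if data.int64_data != []:
                # Shape stored in the repeated int64_data field.
                for dim in data.int64_data:
                    shape.append(dim)
            elif data.raw_data and data.raw_data != []:
                # Shape stored as raw bytes; decode it with numpy_helper.
                shape_in_array = numpy_helper.to_array(data)
                for dim in shape_in_array:
                    shape.append(dim)
    # If op.input[1] is produced by another op instead of an initializer,
    # nothing above matches and shape is still [] at this point.
    outputs = graph.reshape(inputs[0], tuple(shape))
    return outputs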

Looking forward to your fix, thanks!