kendryte / nncase

Open deep learning compiler stack for Kendryte AI accelerators ✨
Apache License 2.0

Compiling a transformer model with shape_bucket fails #1270

Open xiangweizeng opened 1 week ago

xiangweizeng commented 1 week ago

Converting a Transformer model fails when ShapeBucket is enabled; without ShapeBucket the conversion works fine. The error is shown below (nncase 2.4 and 2.9 both give the same result):

Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230_Conv2D
Unhandled exception. System.AggregateException: One or more errors occurred. (Value cannot be null. (Parameter 'key'))
 ---> System.ArgumentNullException: Value cannot be null. (Parameter 'key')
   at System.Collections.Generic.Dictionary`2.TryInsert(TKey key, TValue value, InsertionBehavior behavior)
   at System.Linq.Enumerable.ToDictionary[TSource,TKey,TElement](IEnumerable`1 source, Func`2 keySelector, Func`2 elementSelector, IEqualityComparer`1 comparer)
   at Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper.MakeVarValuesForAllSegment(ShapeBucketOptions options, Int32 segmentCount, Boolean staticShape)
   at Nncase.Passes.Rules.ShapeBucket.RecordFusionShape.RunCoreAsync(BaseFunction main, RunPassContext context)
   at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
   at Nncase.Passes.PassManager.FunctionPassGroup.Runner.RunAsync()
   at Nncase.Passes.PassManager.FunctionPassGroup.RunAsync(IRModule module)
   at Nncase.Passes.PassManager.RunAsync(IRModule module)
   at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
   at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
   --- End of inner exception stack trace ---
   at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
   at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)
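
Judging from the stack trace, the null 'key' handed to ToDictionary inside MakeVarValuesForAllSegment looks as if the bucket variable name cannot be resolved against the imported graph. A quick sanity check (a minimal sketch using the plain onnx API; the path ./tmp/simplified.onnx and the names "seq_len" / "batch_size" are taken from the script below) is to print the symbolic dimensions of every graph input and confirm the ShapeBucket keys actually appear among them:

import onnx

# List the symbolic (dim_param) dimensions of every graph input of the model
# that is actually handed to compiler.import_onnx.
model = onnx.load('./tmp/simplified.onnx')
for inp in model.graph.input:
    dims = [d.dim_param if d.dim_param else d.dim_value
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)
# The keys used in shape_bucket_range_info ("seq_len") and
# shape_bucket_fix_var_map ("batch_size") should show up as dim_param names;
# if every dimension prints as a fixed integer, the bucket variable has
# nothing to bind to.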

Conversion script:

import os
import shutil

import nncase
import numpy as np
import onnx
import onnxsim

def generate_data_encoder(data_dir, input_shapes, data_count):
    data = [[]]
    for i in range(data_count):
        x_batch = np.fromfile(os.path.join(data_dir, 'X_{}.bin'.format(i)), dtype='int64').reshape(input_shapes[0])
        data[0].append(x_batch)
    return data

def parse_model_input_output(model_file, input_shapes_):
    onnx_model = onnx.load(model_file)
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    input_names = list(set(input_all) - set(input_initializer))
    input_tensors = [
        node for node in onnx_model.graph.input if node.name in input_names]

    # input
    inputs = []
    for i, e in enumerate(input_tensors):
        onnx_type = e.type.tensor_type
        input_dict = {
            'name': e.name,
            'dtype': onnx.helper.tensor_dtype_to_np_dtype(onnx_type.elem_type),
            'shape': input_shapes_[i]
        }
        inputs.append(input_dict)
    return onnx_model, inputs

def onnx_simplify(model_file, dump_dir, input_shapes_):
    onnx_model, inputs = parse_model_input_output(model_file, input_shapes_)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    input_shapes = {}
    for input in inputs:
        input_shapes[input['name']] = input['shape']

    onnx_model, check = onnxsim.simplify(onnx_model, overwrite_input_shapes=input_shapes)
    print(onnx.helper.printable_graph(onnx_model.graph))
    assert check, "Simplified ONNX model could not be validated"

    model_file = os.path.join(dump_dir, 'simplified.onnx')
    onnx.save_model(onnx_model, model_file)
    return model_file

def read_model_file(model_file):
    with open(model_file, 'rb') as f:
        model_content = f.read()
    return model_content

def encoder_tokmodel(onnx_model_path, kmodel_path, data_dir, ptq_option, input_shapes, data_count, tmp_path,
                     target='k230'):
    if not os.path.exists(tmp_path):
        os.makedirs(tmp_path)

    # onnx simplify
    model_file = onnx_simplify(onnx_model_path, tmp_path, input_shapes)

    # compile_options
    compile_options = nncase.CompileOptions()
    compile_options.target = target
    compile_options.preprocess = False
    compile_options.dump_ir = True
    compile_options.dump_asm = True
    compile_options.dump_dir = tmp_path

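    # ShapeBucket: compile a static variant of the graph per segment of the
    # dynamic dimension. The keys below are expected to name symbolic
    # dimensions (dim_param) of the imported model: range_info gives the
    # [min, max] of the bucketed dimension, fix_var_map pins the others.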
    compile_options.shape_bucket_enable = True
    compile_options.shape_bucket_range_info = {"seq_len": [1, 64]}
    compile_options.shape_bucket_segments_count = 64
    compile_options.shape_bucket_fix_var_map = {"batch_size": 1}

    # compiler
    compiler = nncase.Compiler(compile_options)

    # import
    model_content = read_model_file(model_file)
    import_options = nncase.ImportOptions()
    compiler.import_onnx(model_content, import_options)

    # ptq_options
    ptq_options = nncase.PTQTensorOptions()
    ptq_options.samples_count = data_count
    if ptq_option == 0:
        pass
    elif ptq_option == 1:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.w_quant_type = 'int16'
    elif ptq_option == 2:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.quant_type = 'int16'
    elif ptq_option == 3:
        ptq_options.w_quant_type = 'int16'
    elif ptq_option == 4:
        ptq_options.quant_type = 'int16'
    ptq_options.set_tensor_data(generate_data_encoder(data_dir, input_shapes, data_count))
    compiler.use_ptq(ptq_options)
    # compile
    compiler.compile()

    # model
    kmodel = compiler.gencode_tobytes()
    with open(kmodel_path, 'wb') as f:
        f.write(kmodel)
    if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)

if __name__ == "__main__":
    encoder_tokmodel(onnx_model_path="onnx/example.onnx",
                     kmodel_path="onnx/example.kmodel",
                     data_dir="generate_data",
                     ptq_option=0,
                     input_shapes=[[1, 64]],
                     data_count=30,
                     tmp_path='./tmp')
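
One detail that may be related (a guess, not verified against the ShapeBucket internals): onnx_simplify calls onnxsim.simplify with overwrite_input_shapes mapping the input to [1, 64], so simplified.onnx ends up with purely static input shapes and no dimension named seq_len. If ShapeBucket binds its variables by dim_param name, that would explain the null key. A hedged sketch for putting the symbolic name back on the simplified model before importing it (the input index 1 and the name "seq_len" are assumptions based on the script above):

import onnx

def restore_symbolic_seq_len(model_file, dim_index=1, dim_name='seq_len'):
    # Re-attach the symbolic name to the sequence dimension of the first
    # graph input, which onnxsim.simplify overwrote with a fixed size.
    model = onnx.load(model_file)
    dim = model.graph.input[0].type.tensor_type.shape.dim[dim_index]
    dim.Clear()               # drop the fixed dim_value
    dim.dim_param = dim_name  # mark the dimension as symbolic again
    onnx.save_model(model, model_file)
    return model_file

# e.g. in encoder_tokmodel, right after onnx_simplify(...):
# model_file = restore_symbolic_seq_len(model_file)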

Run log:

Merge Binary_106_Unary_74_Binary_105_Unary_73
Binary_108_Binary_107
Merge Binary_106_Unary_74_Binary_105_Unary_73_Binary_108_Binary_107
Conv2D_16
Conv2D_17
Conv2D_18
Merge Reshape_233
Binary_111_Binary_109_Binary_110
Binary_114_Binary_112_Binary_113
268
Merge Reshape_233_Binary_111_Binary_109_Binary_110_Binary_114_Binary_112_Binary_113
Reshape_236_Concat_235
Merge Reshape_234
Reshape_233_Binary_111_Binary_109_Binary_110_Binary_114_Binary_112_Binary_113_Reshape_236_Concat_235
Merge Conv2D_19_MatMul_1
Binary_115
Merge Binary_117_Unary_76_Binary_116_Unary_75
Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118
Merge Binary_117_Unary_76_Binary_116_Unary_75_Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118
Binary_120
Merge Conv2D_19_MatMul_1_Binary_115
Binary_117_Unary_76_Binary_116_Unary_75_Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118_Binary_1
Merge Binary_122_Unary_78_Binary_121_Unary_77
Binary_124_Binary_123
Merge Binary_122_Unary_78_Binary_121_Unary_77_Binary_124_Binary_123
Conv2D_23
Conv2D_24
Conv2D_25
Merge Reshape_237
Binary_127_Binary_125_Binary_126
Binary_130_Binary_128_Binary_129
277
Merge Reshape_237_Binary_127_Binary_125_Binary_126_Binary_130_Binary_128_Binary_129
Reshape_240_Concat_239
Merge Reshape_238
Reshape_237_Binary_127_Binary_125_Binary_126_Binary_130_Binary_128_Binary_129_Reshape_240_Concat_239
Merge Conv2D_26_MatMul_3
Binary_131
Merge Binary_133_Unary_80_Binary_132_Unary_79
Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134
Merge Binary_133_Unary_80_Binary_132_Unary_79_Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134
Binary_136
Merge Conv2D_26_MatMul_3_Binary_131
Binary_133_Unary_80_Binary_132_Unary_79_Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134_Binary_1
Merge Binary_138_Unary_82_Binary_137_Unary_81
Binary_140_Binary_139
Merge Binary_138_Unary_82_Binary_137_Unary_81_Binary_140_Binary_139
Conv2D_30
Conv2D_31
Conv2D_32
Merge Reshape_241
Binary_143_Binary_141_Binary_142
Binary_146_Binary_144_Binary_145
286
Merge Reshape_241_Binary_143_Binary_141_Binary_142_Binary_146_Binary_144_Binary_145
Reshape_244_Concat_243
Merge Reshape_242
Reshape_241_Binary_143_Binary_141_Binary_142_Binary_146_Binary_144_Binary_145_Reshape_244_Concat_243
Merge Conv2D_33_MatMul_5
Binary_147
Merge Binary_149_Unary_84_Binary_148_Unary_83
Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150
Merge Binary_149_Unary_84_Binary_148_Unary_83_Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150
Binary_152
Merge Conv2D_33_MatMul_5_Binary_147
Binary_149_Unary_84_Binary_148_Unary_83_Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150_Binary_1
Merge Binary_154_Unary_86_Binary_153_Unary_85
Binary_156_Binary_155
Merge Binary_154_Unary_86_Binary_153_Unary_85_Binary_156_Binary_155
Conv2D_37
Conv2D_38
Conv2D_39
Merge Reshape_245
Binary_159_Binary_157_Binary_158
Binary_162_Binary_160_Binary_161
295
Merge Reshape_245_Binary_159_Binary_157_Binary_158_Binary_162_Binary_160_Binary_161
Reshape_248_Concat_247
Merge Reshape_246
Reshape_245_Binary_159_Binary_157_Binary_158_Binary_162_Binary_160_Binary_161_Reshape_248_Concat_247
Merge Conv2D_40_MatMul_7
Binary_163
Merge Binary_165_Unary_88_Binary_164_Unary_87
Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166
Merge Binary_165_Unary_88_Binary_164_Unary_87_Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166
Binary_168
Merge Conv2D_40_MatMul_7_Binary_163
Binary_165_Unary_88_Binary_164_Unary_87_Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166_Binary_1
Merge Binary_170_Unary_90_Binary_169_Unary_89
Binary_172_Binary_171
Merge Binary_170_Unary_90_Binary_169_Unary_89_Binary_172_Binary_171
Conv2D_44
Conv2D_45
Conv2D_46
Merge Reshape_249
Binary_175_Binary_173_Binary_174
Binary_178_Binary_176_Binary_177
304
Merge Reshape_249_Binary_175_Binary_173_Binary_174_Binary_178_Binary_176_Binary_177
Reshape_252_Concat_251
Merge Reshape_250
Reshape_249_Binary_175_Binary_173_Binary_174_Binary_178_Binary_176_Binary_177_Reshape_252_Concat_251
Merge Conv2D_47_MatMul_9
Binary_179
Merge Binary_181_Unary_92_Binary_180_Unary_91
Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182
Merge Binary_181_Unary_92_Binary_180_Unary_91_Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182
Binary_184
Merge Conv2D_47_MatMul_9_Binary_179
Binary_181_Unary_92_Binary_180_Unary_91_Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182_Binary_1
Merge Binary_186_Unary_94_Binary_185_Unary_93
Binary_188_Binary_187
Merge Binary_186_Unary_94_Binary_185_Unary_93_Binary_188_Binary_187
Conv2D_51
Conv2D_52
Conv2D_53
Merge Reshape_253
Binary_191_Binary_189_Binary_190
Binary_194_Binary_192_Binary_193
313
Merge Reshape_253_Binary_191_Binary_189_Binary_190_Binary_194_Binary_192_Binary_193
Reshape_256_Concat_255
Merge Reshape_254
Reshape_253_Binary_191_Binary_189_Binary_190_Binary_194_Binary_192_Binary_193_Reshape_256_Concat_255
Merge Conv2D_54_MatMul_11
Binary_195
Merge Binary_197_Unary_96_Binary_196_Unary_95
Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198
Merge Binary_197_Unary_96_Binary_196_Unary_95_Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198
Binary_200
Merge Conv2D_54_MatMul_11_Binary_195
Binary_197_Unary_96_Binary_196_Unary_95_Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198_Binary_2
Merge Binary_202_Unary_98_Binary_201_Unary_97
Binary_204_Binary_203
Merge Binary_202_Unary_98_Binary_201_Unary_97_Binary_204_Binary_203
Conv2D_58
Conv2D_59
Conv2D_60
Merge Reshape_257
Binary_207_Binary_205_Binary_206
Binary_210_Binary_208_Binary_209
322
Merge Reshape_257_Binary_207_Binary_205_Binary_206_Binary_210_Binary_208_Binary_209
Reshape_260_Concat_259
Merge Reshape_258
Reshape_257_Binary_207_Binary_205_Binary_206_Binary_210_Binary_208_Binary_209_Reshape_260_Concat_259
Merge Conv2D_61_MatMul_13
Binary_211
Merge Binary_213_Unary_100_Binary_212_Unary_99
Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214
Merge Binary_213_Unary_100_Binary_212_Unary_99_Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214
Binary_216
Merge Conv2D_61_MatMul_13_Binary_211
Binary_213_Unary_100_Binary_212_Unary_99_Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214_Binary_
Merge Binary_218_Unary_102_Binary_217_Unary_101
Binary_220_Binary_219
Merge Binary_218_Unary_102_Binary_217_Unary_101_Binary_220_Binary_219
Conv2D_65
Conv2D_66
Conv2D_67
Merge Reshape_261
Binary_223_Binary_221_Binary_222
Binary_226_Binary_224_Binary_225
331
Merge Reshape_261_Binary_223_Binary_221_Binary_222_Binary_226_Binary_224_Binary_225
Reshape_264_Concat_263
Merge Reshape_262
Reshape_261_Binary_223_Binary_221_Binary_222_Binary_226_Binary_224_Binary_225_Reshape_264_Concat_263
Merge Conv2D_68_MatMul_15
Binary_227
Merge Binary_229_Unary_104_Binary_228_Unary_103
Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230
Merge Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230
Conv2D_72
Merge Conv2D_68_MatMul_15_Binary_227
Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230_Conv2D
(The log ends with the same AggregateException / ArgumentNullException stack trace quoted above.)