kendryte / nncase

Open deep learning compiler stack for Kendryte AI accelerators ✨
Apache License 2.0
745 stars 181 forks source link

自己训练并转换的模型在 K210 上部署报错 #1254

Open 17168220609 opened 1 week ago

17168220609 commented 1 week ago

这是我的模型架构


import torch
import torch.nn as nn
import onnx
import os

class CharacterClassifier(nn.Module):
    """Small CNN that maps a 3-channel character image to ``num_classes`` logits.

    Architecture: four Conv(3x3)-BatchNorm-ReLU-MaxPool(2x2) stages
    (3 -> 16 -> 32 -> 64 -> 128 channels), an adaptive average pool to a fixed
    2x4 grid, then a Dropout/Linear/ReLU/BatchNorm1d/Dropout/Linear head.

    With a 32x64 input the four max-pools already shrink the feature map to
    2x4, so the adaptive pool is an identity there; it only resamples for
    other input sizes.

    The order of layers inside ``features`` and ``classifier`` is unchanged, so
    ``state_dict`` keys (``features.0.weight``, ``classifier.1.weight``, ...)
    stay compatible with previously trained checkpoints.

    Args:
        num_classes: Number of output classes (default 13).
    """

    def __init__(self, num_classes=13):
        super().__init__()

        # Channel count after the last conv stage and the fixed grid produced
        # by the adaptive pool. The first Linear layer's input width is derived
        # from both, so they cannot silently drift apart.
        feature_channels = 128
        pooled_size = (2, 4)

        self.features = nn.Sequential(
            # First convolutional block
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Second convolutional block
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Third convolutional block
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Fourth convolutional block
            nn.Conv2d(64, feature_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(feature_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Adaptive average pool to a fixed spatial size. This is not a
            # *global* pool: the output grid is 2x4, not 1x1.
            nn.AdaptiveAvgPool2d(pooled_size),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_channels * pooled_size[0] * pooled_size[1], 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize conv, batch-norm and linear layers in place.

        Convs use Kaiming-normal weights (fan_out, ReLU gain) with zero bias.
        Both 2-D and 1-D batch-norm layers get affine weight 1 and bias 0.
        Linear layers use N(0, 0.01) weights with zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an (N, 3, H, W) batch."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 128, 2, 4) -> (N, 1024)
        return self.classifier(x)

def export_model_to_onnx(model, input_shape=(1, 3, 32, 64), path="character_classifier.onnx",
                         opset_version=None):
    """Export ``model`` to ONNX at ``path``, validate the file and print its size.

    The graph is traced with a random dummy input of ``input_shape``. No
    dynamic axes are declared, so every dimension (including batch) is fixed
    in the exported model.

    Args:
        model: The ``torch.nn.Module`` to export. With its default
            ``training`` setting, ``torch.onnx.export`` exports in inference
            mode, so Dropout/BatchNorm appear with their eval behavior.
        input_shape: Shape of the dummy input tensor, e.g. (N, C, H, W).
        path: Destination ``.onnx`` file.
        opset_version: ONNX opset to target. ``None`` (default) keeps the
            exporter's own default. Pin it when the downstream compiler only
            accepts particular opsets.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the exported file is
        not valid ONNX.
    """
    # Create the dummy input on the same device as the model's parameters, so
    # tracing a model that lives on a GPU does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cpu"
    dummy_input = torch.randn(input_shape, device=device)

    torch.onnx.export(model,
                      dummy_input,
                      path,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=None,
                      opset_version=opset_version)

    # Reload the written file and run the ONNX structural checker on it.
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)

    model_size = os.path.getsize(path) / (1024 * 1024)  # bytes -> MiB
    print(f"ONNX model size: {model_size:.2f} MB")

if __name__ == "__main__":
    # Build a model and write it to the default ONNX path in one step.
    # NOTE(review): CharacterClassifier() here has freshly initialized weights;
    # load the trained state_dict before exporting if the goal is deployment.
    export_model_to_onnx(CharacterClassifier())
下面是我将 ONNX 模型转换为 kmodel 的代码：
```python
import os
import onnxsim
import onnx
import nncase

def _resolve_input_shape(name, dims, default_shape):
    """Return concrete sizes for ``dims``, filling unknown (0) dims from ``default_shape``."""
    shape = []
    for axis, dim in enumerate(dims):
        if dim.dim_value != 0:
            # Static dimension: always keep it, whatever the input's rank.
            shape.append(dim.dim_value)
        elif axis < len(default_shape):
            shape.append(default_shape[axis])
        else:
            raise ValueError(
                f"input {name!r} has an unknown dimension at axis {axis} "
                f"and default_shape {tuple(default_shape)} has no value for it"
            )
    return shape


def parse_model_input_output(model_file, default_shape=(1, 3, 32, 64)):
    """Load an ONNX model and describe its real (non-initializer) inputs.

    Some exporters also list weight initializers in ``graph.input``. Those are
    filtered out, so only the genuine runtime inputs remain.

    Args:
        model_file: Path to the ``.onnx`` file.
        default_shape: Fallback size per axis for dimensions whose
            ``dim_value`` is 0 (symbolic or unknown).

    Returns:
        A ``(onnx_model, inputs)`` tuple. ``inputs`` is a list of dicts with
        keys ``'name'``, ``'dtype'`` (NumPy dtype) and ``'shape'`` (list of
        ints), in graph-input order.

    Raises:
        ValueError: If an input has an unknown dimension at an axis for which
            ``default_shape`` provides no fallback. (Previously such axes,
            and any axes beyond rank 4, were silently dropped.)
    """
    onnx_model = onnx.load(model_file)
    initializer_names = {init.name for init in onnx_model.graph.initializer}

    inputs = []
    for tensor in onnx_model.graph.input:
        if tensor.name in initializer_names:
            continue
        tensor_type = tensor.type.tensor_type
        inputs.append({
            'name': tensor.name,
            'dtype': onnx.helper.tensor_dtype_to_np_dtype(tensor_type.elem_type),
            'shape': _resolve_input_shape(tensor.name, tensor_type.shape.dim, default_shape),
        })

    return onnx_model, inputs

def onnx_simplify(model_file, output_file=None):
    """Simplify an ONNX model with onnx-simplifier and save the result.

    Args:
        model_file: Path to the source ``.onnx`` model.
        output_file: Where to write the simplified model. Defaults to
            ``simplified.onnx`` in the same directory as ``model_file``.

    Returns:
        The path of the written simplified model.

    Raises:
        RuntimeError: If onnx-simplifier reports that the simplified model
            could not be validated against the original.
    """
    onnx_model, inputs = parse_model_input_output(model_file)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

    # Concrete shape per real input, passed to onnxsim for its validation run.
    input_shapes = {spec['name']: spec['shape'] for spec in inputs}

    onnx_model, check = onnxsim.simplify(onnx_model, test_input_shapes=input_shapes)
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    if output_file is None:
        output_file = os.path.join(os.path.dirname(model_file), 'simplified.onnx')
    onnx.save_model(onnx_model, output_file)
    return output_file

def read_model_file(model_file):
    """Return the entire contents of ``model_file`` as ``bytes``."""
    with open(model_file, mode='rb') as handle:
        return handle.read()

def main(model_file="character_classifier.onnx", target='k210', kmodel_file='1.kmodel',
         dump_dir=None):
    """Convert an ONNX model to a kmodel with nncase.

    Usage: ``python onnx2kmodel.py``

    Steps:
        1. Simplify the ONNX graph (``simplified.onnx`` is written next to
           ``model_file``).
        2. Import it into an nncase compiler configured for ``target``.
        3. Compile.
        4. Write the generated kmodel bytes to ``kmodel_file``.

    Args:
        model_file: Source ONNX model path.
        target: nncase compile target (``'k210'`` by default).
        kmodel_file: Output path for the generated kmodel.
        dump_dir: Optional directory for the IR/asm dumps enabled below.
            ``None`` leaves nncase's own default.

    NOTE(review): the kmodel produced by this nncase version must be loaded by
    a matching runtime on the device. If loading fails on the K210 while
    compilation succeeds, check that the firmware/runtime was updated to
    match. Also, no PTQ options are configured here, so the model is compiled
    without quantization; confirm against the nncase docs whether that suits
    the K210 target.
    """
    model_file = onnx_simplify(model_file)

    # compile_options
    compile_options = nncase.CompileOptions()
    compile_options.target = target
    compile_options.dump_ir = True
    compile_options.dump_asm = True
    if dump_dir is not None:
        compile_options.dump_dir = dump_dir
    compile_options.input_type = 'float32'  # or 'uint8' 'int8'
    compile_options.output_type = 'float32'  # or 'uint8' 'int8'. Only takes effect with PTQ
    compiler = nncase.Compiler(compile_options)

    # import the simplified ONNX graph as raw bytes
    import_options = nncase.ImportOptions()
    model_content = read_model_file(model_file)
    compiler.import_onnx(model_content, import_options)

    compiler.compile()

    # write the kmodel
    kmodel = compiler.gencode_tobytes()
    with open(kmodel_file, 'wb') as f:
        f.write(kmodel)

if __name__ == "__main__":
    # Run the conversion with its default paths and target.
    main()
```
然后部署到 K210 上就报错了！
![K210 部署报错截图](https://github.com/user-attachments/assets/e4214398-ad0c-41f5-80a4-dafc56148e66)
模型转换过程本身没有报错：
![image](https://github.com/user-attachments/assets/672e836c-4b53-41c4-8fc6-af95c254177a)
curioyang commented 1 week ago

@17168220609 你用的是 nncase 1.x 版本，那么设备上的 runtime 也更新了吗？MaixPy 的 runtime 是否已同步更新我不太清楚；官方 C++ 版本的 runtime 可以查看 QQ 群的入群公告。