onnx / optimizer

Actively maintained ONNX Optimizer
Apache License 2.0
650 stars · 90 forks · source link

fix missing functions #137

Status: Open. AlexHJH opened this 1 year ago.

AlexHJH commented 1 year ago

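Use the following script to test the fix. It exports a ViT model with `nn.LayerNorm` and `nn.GELU` emitted as ONNX model-local functions, runs `onnxoptimizer.optimize`, and checks that the optimized model still contains the same number of functions.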

# Reproduction / regression check for onnx/optimizer #137: the optimizer must
# keep every model-local function (ModelProto.functions) emitted by the exporter.
import timm
import torch
import torch.nn as nn

# Step 1: export a ViT model to ONNX, emitting every nn.LayerNorm and nn.GELU
# submodule as an ONNX local function instead of inlining it into the graph.
# export_modules_as_functions is only supported for opset_version >= 15 (see
# the torch.onnx.export documentation), hence opset 15.
with timm.set_exportable(True):  # build the model in timm's "exportable" mode
    module = timm.create_model("vit_tiny_r_s16_p8_224")
    torch.onnx.export(
        module,
        (torch.ones([1, 3, 224, 224], dtype=torch.float32), ),
        "test.onnx",
        opset_version=15,
        export_modules_as_functions={
            nn.LayerNorm,
            nn.GELU
        }
    )

import onnxoptimizer
import onnx

# Step 2: run the default optimization passes and save the optimized model.
model = onnx.load("test.onnx")
opt_model = onnxoptimizer.optimize(model)
onnx.save(opt_model, "opt.onnx")

# Step 3: verify that no model-local functions were lost. An explicit raise is
# used instead of a bare `assert` so the check still runs under `python -O`.
# The message names the missing functions so a failure is actionable.
if len(model.functions) != len(opt_model.functions):
    kept = {(f.domain, f.name) for f in opt_model.functions}
    missing = sorted(
        f"{f.domain}::{f.name}"
        for f in model.functions
        if (f.domain, f.name) not in kept
    )
    raise AssertionError(
        "optimizer changed the number of model-local functions: "
        f"{len(model.functions)} -> {len(opt_model.functions)}; "
        f"missing: {missing}"
    )