Status: Open — opened by AlexHJH 1 year ago.
Use the following code to reproduce the problem:
# Reproducer: export a timm ViT with nn.LayerNorm and nn.GELU emitted as ONNX
# model-local functions, run onnxoptimizer over the exported model, and check
# that the optimized model still has as many local functions as the original.
import onnx
import onnxoptimizer
import timm
import torch
import torch.nn as nn

EXPORT_PATH = "test.onnx"
OPTIMIZED_PATH = "opt.onnx"


def _function_ids(model):
    """Return the sorted ``(domain, name)`` pairs of *model*'s local functions."""
    return sorted((func.domain, func.name) for func in model.functions)


def main():
    """Export, optimize, save, and verify that local functions survive optimization."""
    with timm.set_exportable(True):
        module = timm.create_model("vit_tiny_r_s16_p8_224")
        torch.onnx.export(
            module,
            (torch.ones([1, 3, 224, 224], dtype=torch.float32),),
            EXPORT_PATH,
            # export_modules_as_functions is only supported for opset >= 15.
            opset_version=15,
            export_modules_as_functions={nn.LayerNorm, nn.GELU},
        )

    model = onnx.load(EXPORT_PATH)
    opt_model = onnxoptimizer.optimize(model)
    onnx.save(opt_model, OPTIMIZED_PATH)

    before = _function_ids(model)
    after = _function_ids(opt_model)
    # Raise explicitly rather than `assert` so the check still runs under
    # `python -O`, and list which functions were lost to aid debugging.
    if len(after) != len(before):
        missing = sorted(set(before) - set(after))
        raise AssertionError(
            f"onnxoptimizer changed the number of local functions: "
            f"{len(before)} before, {len(after)} after; missing: {missing}"
        )


if __name__ == "__main__":
    main()
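Run the script above: its final check verifies that `onnxoptimizer.optimize` keeps the model's local functions (`model.functions`) intact.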