Status: Open — opened by zccyman 6 months ago.
import torch
import torch.nn.functional as F


def export_linalg_by_shark(model, dummy_input, output_path="test.mlir"):
    """Trace ``model`` with SHARK's torch importer and write linalg MLIR to a file.

    Args:
        model: the ``torch.nn.Module`` to export.
        dummy_input: a single example input; it is wrapped in a 1-tuple
            before being passed to ``SharkImporter``.
        output_path: file the MLIR text is written to. Defaults to
            ``"test.mlir"``, the path that was previously hard-coded.

    Returns:
        The MLIR text that was written to ``output_path``.
    """
    # Imported inside the function, so the SHARK extension is only needed
    # when an export is actually requested.
    from extension.shark.shark_importer import SharkImporter

    mlir_type = "linalg"
    mlir_importer = SharkImporter(
        model,
        (dummy_input,),
        frontend="torch",
        return_str=True,
    )
    # NOTE(review): ``_torch_mlir`` is a private SharkImporter method; its
    # signature may change between SHARK versions.
    mlir_str = mlir_importer._torch_mlir(
        is_dynamic=False, tracing_required=True, mlir_type=mlir_type
    )
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(mlir_str)
    return mlir_str


class SimpleNet(torch.nn.Module):
    """Single 3x3 int8 convolution (12 -> 64 channels) with a constant weight.

    Minimal repro for exporting an int8 ``F.conv2d`` to linalg MLIR.
    """

    def forward(self, x):
        # Constant int8 weight of 100s, shape (out=64, in=12, kH=3, kW=3).
        # Kept as ``ones * 100`` (not ``torch.full``) so the traced graph
        # matches the originally reported repro.
        weight = 100 * torch.ones(64, 12, 3, 3).type(torch.int8)
        # NOTE(review): for the all-100 input below, each output element is
        # 100 * 100 * 12 * 9 = 1_080_000, far outside the int8 range
        # [-128, 127]. Real int8 convolutions usually accumulate in int32;
        # confirm which result semantics the exported linalg should have.
        return F.conv2d(
            input=x.type(torch.int8),
            weight=weight,
            bias=None,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,
        )


model = SimpleNet()
model.eval()

# Example input for conv2d (N, C, H, W) = (1, 12, 224, 224), all elements 100.
# Named ``example_input`` rather than ``input`` to avoid shadowing the builtin.
example_input = 100 * torch.ones(1, 12, 224, 224).type(torch.int8)

# Eager reference run before export.
# NOTE(review): PyTorch's CPU conv2d kernels may not implement int8 ("Char")
# tensors; if this call raises, the failure is in eager PyTorch rather than in
# the MLIR export -- confirm on the PyTorch version in use.
output = model(example_input)

export_linalg_by_shark(
    model,
    example_input,
)
print("test")
When input, weight, and bias are all torch.float32 (instead of the int8 tensors used in the script above), the linalg export succeeds.