nod-ai / SHARK-Studio

SHARK Studio -- Web UI for SHARK+IREE High Performance Machine Learning Distribution
Apache License 2.0
1.42k stars 170 forks source link

How to export TOSA/linalg when conv input and weight are torch.int8 and bias is torch.int32 #2135

Open zccyman opened 6 months ago

zccyman commented 6 months ago
import torch
import torch.nn.functional as F

def export_linalg_by_shark(
    model, dummy_input, mlir_type="linalg", output_path="test.mlir"
):
    """Import ``model`` through SHARK's torch frontend and write the MLIR text.

    Args:
        model: The ``torch.nn.Module`` to import.
        dummy_input: Example input tensor; it is passed to the importer as a
            one-element tuple.
        mlir_type: Output dialect forwarded to the importer. Defaults to
            ``"linalg"``; pass e.g. ``"tosa"`` for TOSA output.
            NOTE(review): the accepted values are defined by the importer /
            torch-mlir, not here -- confirm there.
        output_path: File the MLIR text is written to (overwritten if it
            exists). Defaults to ``"test.mlir"``.

    Returns:
        The MLIR text that was written to ``output_path``.
    """
    # Project-local import kept inside the function so the module can be
    # imported without SHARK on the path.
    from extension.shark.shark_importer import SharkImporter

    mlir_importer = SharkImporter(
        model,
        (dummy_input,),
        frontend="torch",
        return_str=True,  # ask for textual MLIR so it can be written as text
    )
    # NOTE(review): ``_torch_mlir`` is a private SharkImporter method; with
    # tracing_required=True the model is presumably traced, i.e. ``forward``
    # is executed on ``dummy_input`` during import -- confirm in SharkImporter.
    mlir_str = mlir_importer._torch_mlir(
        is_dynamic=False, tracing_required=True, mlir_type=mlir_type
    )
    # Explicit encoding so the output does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(mlir_str)
    return mlir_str

class SimpleNet(torch.nn.Module):
    """Minimal repro module: a single int8 ``conv2d`` with a constant kernel.

    The module registers no parameters or buffers; the kernel tensor is
    rebuilt from scratch on every ``forward`` call.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Convolve ``x`` (cast to int8) with a 64x12x3x3 int8 kernel of 100s.

        ``x`` must have 12 channels, because the kernel is 64x12x3x3 and
        ``groups=1``. The convolution uses no bias, stride 1, no padding and
        dilation 1. It returns the raw ``F.conv2d`` result.
        """
        kernel = 100 * torch.ones(64, 12, 3, 3).type(torch.int8)
        x_int8 = x.type(torch.int8)
        return F.conv2d(
            input=x_int8,
            weight=kernel,
            bias=None,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,
        )

model = SimpleNet()
model.eval()

# Named ``dummy_input`` (not ``input``) so the builtin ``input`` is not
# shadowed at module scope. Same value as before: an int8 tensor of 100s.
dummy_input = torch.full((1, 12, 224, 224), 100, dtype=torch.int8)

# Eager sanity run before export.
# NOTE(review): stock PyTorch CPU kernels may not implement int8 convolution.
# If this line raises, tracing during export will presumably hit the same
# error, because tracing executes ``forward``.
output = model(dummy_input)

export_linalg_by_shark(
    model,
    dummy_input,
)
print("test")

When input, weight, and bias are all torch.float32, the linalg export succeeds.