fbelderink / flutter_pytorch_mobile

A flutter plugin for pytorch model inference. Supports image models as well as custom models.
https://pub.dev/packages/pytorch_mobile

Your input type float32 (Float) does not match with model input type #27

Closed KonstiDE closed 1 year ago

KonstiDE commented 1 year ago

Hello, I was playing around a bit and tried to recreate the PyTorch code for the custom_model in the example with:

import torch
import torch.nn as nn
import torch.utils.mobile_optimizer as mobile
from torchsummary import summary

class FullyNet(nn.Module):
    def __init__(self):
        super(FullyNet, self).__init__()

        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        x = x.view(-1)

        return self.fc(x)

if __name__ == '__main__':
    net = FullyNet()

    net.eval()
    quantized_model = torch.quantization.convert(net)
    scripted_model = torch.jit.script(quantized_model)
    opt_model = mobile.optimize_for_mobile(scripted_model)
    opt_model.save('fully_net.pt')

    out = net(torch.randn(1, 2, 2))
    print(out)

However, I get "Your input type float32 (Float) does not match with model input type". The full stack trace is linked below... Is there some way to fix that?

Full stack trace: https://hastebin.com/share/ibuvorukuq.css

PS: If I use the custom_model from your examples, it works like a charm :)

KonstiDE commented 1 year ago

Fixed it: save your model with torch.jit.trace, the way the plugin's example does:

example = torch.rand(1, 2, 2)
traced_script_module = torch.jit.trace(net, example)
traced_script_module.save("fully_net.pt")
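(For anyone hitting this later, here is the fix as a self-contained sketch, reusing the FullyNet definition from the issue. Tracing records the graph from a concrete float32 example input, and the sanity check below confirms the traced module reproduces the eager model's output. The optimize_for_mobile step from the original script can still be applied to the traced module.)

```python
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class FullyNet(nn.Module):
    def __init__(self):
        super(FullyNet, self).__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x.view(-1))

net = FullyNet().eval()
example = torch.rand(1, 2, 2)

# trace with a float32 example input so the saved graph expects Float tensors
traced = torch.jit.trace(net, example)

# sanity check: the traced module matches the eager model on the same input
assert torch.allclose(net(example), traced(example))

# the mobile-optimizer pass still works on a traced module
opt_model = optimize_for_mobile(traced)
opt_model.save("fully_net.pt")
```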