naoto0804 / pytorch-AdaIN

Unofficial pytorch implementation of 'Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization' [Huang+, ICCV2017]
MIT License

How to convert the net model from .pth to .onnx #56

Open mhy-666 opened 1 year ago

mhy-666 commented 1 year ago

I used the following code to initialize the net model and save it as save.pth:

import torch
import torch.nn as nn

import net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
decoder = net.decoder
vgg = net.vgg

decoder.eval()
vgg.eval()

# load the pretrained weights shipped with the repo
decoder.load_state_dict(torch.load("models/decoder.pth"))
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
# keep only the encoder layers up to relu4_1, as in test.py
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder.to(device)

# combine encoder and decoder and save the full state dict
model = net.Net(vgg, decoder)
torch.save(model.state_dict(), "save.pth")

Then I load the saved model and try to export it to ONNX:

import cv2

# model_new is built the same way as above, i.e. net.Net(vgg, decoder), before loading the weights
model_new = net.Net(vgg, decoder)
model_new.load_state_dict(torch.load("save.pth"))
model_new.eval()

batch_size = 5
img = cv2.imread('input/content/avril.jpg')
input_height = img.shape[0]
input_width = img.shape[1]
input_channels = img.shape[2]
output_channels = 3
output_height = 512
output_width = 512
dummy_input = torch.randn(1, 3, 512, 512)
input_name = 'input'
output_name = 'output'

torch.onnx.export(model_new,
                  dummy_input,
                  'AdaIN_style_transfer.onnx',
                  opset_version=11,
                  verbose=True,
                  input_names=[input_name],
                  output_names=[output_name],
                  dynamic_axes={
                      input_name: {0: 'batch_size', 1: 'input_channels', 2: 'input_height', 3: 'input_width'},
                      output_name: {0: 'batch_size', 1: 'output_channels', 2: 'output_height', 3: 'output_width'}})

but I got the following error:

TypeError: forward() missing 1 required positional argument: 'style'

How do I solve the problem?
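
For context, net.Net.forward(content, style) in this repo expects both a content and a style tensor (and returns the training losses rather than the stylized image), so exporting it with a single dummy input raises exactly this error. Below is a minimal, untested sketch of one possible workaround: wrap the inference path (encoder -> AdaIN -> decoder) in a small module and pass two dummy inputs to torch.onnx.export as a tuple. The StyleTransfer class is a hypothetical helper, not part of the repo; it assumes the repo's function.adaptive_instance_normalization.

import torch
import torch.nn as nn

import net
from function import adaptive_instance_normalization as adain


class StyleTransfer(nn.Module):
    # hypothetical wrapper: encoder -> AdaIN -> decoder, roughly mirroring the inference path in test.py
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, content, style):
        content_feat = self.encoder(content)
        style_feat = self.encoder(style)
        t = adain(content_feat, style_feat)
        return self.decoder(t)


vgg = net.vgg
decoder = net.decoder
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
decoder.load_state_dict(torch.load('models/decoder.pth'))
vgg = nn.Sequential(*list(vgg.children())[:31])

model = StyleTransfer(vgg, decoder).eval()

dummy_content = torch.randn(1, 3, 512, 512)
dummy_style = torch.randn(1, 3, 512, 512)

# torch.onnx.export passes the elements of the args tuple as separate positional inputs,
# so both 'content' and 'style' reach forward()
torch.onnx.export(model,
                  (dummy_content, dummy_style),
                  'AdaIN_style_transfer.onnx',
                  opset_version=11,
                  input_names=['content', 'style'],
                  output_names=['output'])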