Open mhy-666 opened 1 year ago
I used the following to initial the net model and save it as save.pth:
device = torch.device('cuda' if torch.cuda.is_available() else "cpu") decoder = net.decoder vgg = net.vgg decoder.eval() vgg.eval() decoder.load_state_dict(torch.load("models/decoder.pth")) vgg.load_state_dict(torch.load('models/vgg_normalised.pth')) vgg = nn.Sequential(*list(vgg.children())[:31]) vgg.to(device) decoder.to(device) model = net.Net(vgg, decoder) # model.load_state_dict(torch.load("models/decoder.pth"), strict=False) device = torch.device('cuda' if torch.cuda.is_available() else "cpu") torch.save(model.state_dict(), "save.pth")
and I load the model:
model_new.load_state_dict(torch.load("save.pth")) model_new.eval() batch_size=5 img=cv2.imread('input/content/avril.jpg') input_height=img.shape[0] input_width=img.shape[1] input_channels=img.shape[2] output_channels=3 output_height=512 output_width=512 dummy_input = torch.randn(1, 3, 512, 512) input_name = 'input' output_name = 'output' torch.onnx.export(model_new, dummy_input, 'AdaIN_style_transfer.onnx', opset_version=11, verbose = True, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 1: 'input_channels', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 1: 'output_channels', 2: 'output_height', 3: 'output_width'}})
but I got the following error:
TypeError: forward() missing 1 required positional argument: 'style'
How do I solve the problem?
I used the following to initial the net model and save it as save.pth:
and I load the model:
but I got the following error:
How do I solve the problem?