Open kaczmarj opened 1 year ago
So far I have tried the following code, but I am getting a TypeError.
import torch
from model_arch import UnetVggMultihead
model = UnetVggMultihead(load_weights=False)
torch.jit.script(model)
This is the error:
TypeError: 'numpy.int64' object in attribute 'Conv2d.out_channels' is not a valid constant.
Valid constants are:
1. a nn.ModuleList
2. a value of type {bool, float, int, str, NoneType, torch.device, torch.layout, torch.dtype}
3. a list or tuple of (2)
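The message suggests that a channel count reached `nn.Conv2d` as a NumPy integer instead of a built-in `int`. Here is a minimal sketch of one possible fix, casting the value with `int(...)` before building the layer. How `model_arch` actually computes its channel counts is an assumption on my part, and `n_channels` and the small `nn.Sequential` stand-in are hypothetical.

```python
import numpy as np
import torch
from torch import nn

# A channel count computed with NumPy is a NumPy integer scalar, not a Python int.
# (Hypothetical value; the real model's channel arithmetic is not shown in this issue.)
n_channels = np.array([4, 8]).sum()

# Casting to a built-in int keeps Conv2d.out_channels a valid TorchScript constant.
layer = nn.Sequential(
    nn.Conv2d(3, int(n_channels), kernel_size=3, padding=1),
    nn.ReLU(),
)

scripted = torch.jit.script(layer)
out_shape = tuple(scripted(torch.ones(1, 3, 16, 16)).shape)
```

If the same cast were applied wherever the model passes NumPy values to layer constructors, `torch.jit.script` might get past this particular error, though other scripting errors could still come up.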
I got it working with torch.jit.trace.
import torch
from model_arch import UnetVggMultihead
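model = UnetVggMultihead(load_weights=False)
# Trace in eval mode so layers such as dropout and batch norm use inference behavior
model.eval()
script = torch.jit.trace(model, torch.ones(1, 3, 448, 448))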
# Freeze the model, and do some other optimizations...
script_opt = torch.jit.optimize_for_inference(script)
# Save
torch.jit.save(script_opt, "model-torchscript-opt.pth")
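Since tracing only records the operations executed for one example input, it seems worth checking that the saved file reloads and matches the eager model on other inputs. Below is a rough sketch of such a check. It uses a small hypothetical stand-in network and a temporary file name (both assumptions) so it can run without `model_arch` or the trained weights.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for the real network, so the sketch runs on its own.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU()).eval()
example = torch.ones(1, 3, 32, 32)

traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmpdir:
    # Hypothetical file name, written to a temporary directory.
    path = os.path.join(tmpdir, "stand-in-torchscript.pth")
    torch.jit.save(traced, path)
    reloaded = torch.jit.load(path)

    # Compare eager and reloaded outputs on an input that was not used for tracing.
    new_input = torch.rand(2, 3, 48, 48)
    with torch.no_grad():
        outputs_match = torch.allclose(
            model(new_input), reloaded(new_input), atol=1e-6
        )
```

For the real model, the same comparison could be run with `UnetVggMultihead` in place of the stand-in. A mismatch on inputs of a different size would hint that the trace captured shape-dependent control flow.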
Hi @ShahiraAbousamra, I would like to convert the trained MCSpatNet model to TorchScript so we can use it in QuPath. I'm opening this issue to track my progress.