TopoXLab / MCSpatNet

Repository for ICCV2021 MCSpatNet: Multi-Class Cell Detection Using Spatial Context Representation
MIT License

Convert the trained model to TorchScript #3

Status: Open. kaczmarj opened this issue 1 year ago.

kaczmarj commented 1 year ago

Hi @ShahiraAbousamra, I would like to convert the trained MCSpatNet model to TorchScript so that we can use it in QuPath. I'm opening this issue to track my progress.

kaczmarj commented 1 year ago

So far I have tried the following code, but I get a TypeError:

import torch
from model_arch import UnetVggMultihead

model = UnetVggMultihead(load_weights=False)
torch.jit.script(model)

This is the error:


TypeError: 
'numpy.int64' object in attribute 'Conv2d.out_channels' is not a valid constant.
Valid constants are:
1. a nn.ModuleList
2. a value of type {bool, float, int, str, NoneType, torch.device, torch.layout, torch.dtype}
3. a list or tuple of (2)
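The error comes from TorchScript's constant check: nn.Conv2d lists attributes such as out_channels in its __constants__, and torch.jit.script only accepts plain Python types there, so a numpy.int64 channel count is rejected. Below is a minimal sketch of one workaround, assuming the NumPy values sit only in those constant attributes: cast them to Python ints before scripting. `cast_numpy_constants` and `Toy` are hypothetical names for illustration, not part of MCSpatNet, and on the real UnetVggMultihead scripting may still fail for other reasons if its forward uses Python features TorchScript does not support.

```python
import numpy as np
import torch
import torch.nn as nn


def cast_numpy_constants(model: nn.Module) -> nn.Module:
    """Hypothetical helper: turn NumPy integers in TorchScript constant attributes into Python ints.

    Only attributes named in each submodule's ``__constants__`` are touched,
    so parameters and buffers are left unchanged.
    """
    for module in model.modules():
        for name in getattr(module, "__constants__", ()):
            value = getattr(module, name, None)
            if isinstance(value, np.integer):
                setattr(module, name, int(value))
            elif isinstance(value, tuple):
                # e.g. kernel_size or stride tuples that may contain NumPy integers
                setattr(module, name, tuple(int(v) if isinstance(v, np.integer) else v for v in value))
    return model


class Toy(nn.Module):
    """Hypothetical stand-in that reproduces the reported error, not MCSpatNet's model."""

    def __init__(self):
        super().__init__()
        # A channel count that ended up as numpy.int64, as in the TypeError above.
        self.conv = nn.Conv2d(3, np.int64(4), kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = cast_numpy_constants(Toy())
scripted = torch.jit.script(model)  # without the cast, scripting raises the TypeError above
out = scripted(torch.ones(1, 3, 8, 8))
```

Fixing it at the source, by wrapping the value in int(...) where the layer is constructed in model_arch, would avoid the need for a post-hoc pass.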
kaczmarj commented 1 year ago

I got it working with torch.jit.trace instead:

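import torch
from model_arch import UnetVggMultihead

model = UnetVggMultihead(load_weights=False)
# To export the trained model, load its checkpoint into `model` here (not shown).
# Switch to eval mode before tracing: the tracer records the ops it sees, so
# dropout or batch norm layers would otherwise be captured in training mode.
model.eval()
example = torch.ones(1, 3, 448, 448)  # dummy input: one 3-channel 448x448 image
with torch.no_grad():
    script = torch.jit.trace(model, example)
# Freeze the model and apply other inference optimizations
script_opt = torch.jit.optimize_for_inference(script)
# Save the optimized TorchScript module
torch.jit.save(script_opt, "model-torchscript-opt.pth")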
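The PyTorch documentation for torch.jit.optimize_for_inference notes that serializing a module after that call is not guaranteed to work, because it bakes in build-specific settings. A more conservative sketch saves a frozen trace instead, reloads it with torch.jit.load, and compares its outputs with the eager model before the file is used elsewhere. `TwoHeadNet` and `export_and_check` are hypothetical names, and the small network only stands in for UnetVggMultihead because model_arch is not available here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TwoHeadNet(nn.Module):
    """Hypothetical stand-in for UnetVggMultihead: shared conv body, two output heads."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )
        self.head_a = nn.Conv2d(8, 2, kernel_size=1)
        self.head_b = nn.Conv2d(8, 3, kernel_size=1)

    def forward(self, x):
        feats = self.body(x)
        return self.head_a(feats), self.head_b(feats)


def export_and_check(model, example, path, atol=1e-4):
    """Trace, freeze and save `model`, then reload it and compare against eager outputs."""
    model.eval()  # record inference-mode behavior of batch norm / dropout
    with torch.no_grad():
        frozen = torch.jit.freeze(torch.jit.trace(model, example))
    torch.jit.save(frozen, path)
    loaded = torch.jit.load(path)
    with torch.no_grad():
        expected = model(example)
        actual = loaded(example)
    # Treat a single-tensor output like a one-element tuple.
    if isinstance(expected, torch.Tensor):
        expected, actual = (expected,), (actual,)
    return all(torch.allclose(e, a, atol=atol) for e, a in zip(expected, actual))


torch.manual_seed(0)
with tempfile.TemporaryDirectory() as tmp:
    ok = export_and_check(TwoHeadNet(), torch.randn(1, 3, 32, 32), os.path.join(tmp, "model.pt"))
```

For the real model, the same function would be called with the UnetVggMultihead instance (after loading its trained weights) and torch.ones(1, 3, 448, 448). The tolerance is loose because torch.jit.freeze, with its default optimize_numerics=True, folds batch norm into the preceding convolution, which changes results slightly. If the speedups from optimize_for_inference are wanted, one option is to apply it after loading, on the machine that runs inference.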