nkhlS141 opened 9 months ago
The `cfg` is from a Mask R-CNN model, so `out` should contain the segmentation masks. Adding them to `res` should turn `Wrapper` into a Mask R-CNN wrapper.
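Which index of `out` holds the masks depends on how the model was exported. Assuming it went through detectron2's tracing export, my understanding is that the fields of the predicted `Instances` are flattened in sorted key order (possibly followed by the image size). That would put Mask R-CNN outputs in the order boxes, classes, masks, scores. The plain-Python sketch below only illustrates that index bookkeeping. The field names are detectron2's `Instances` fields; the sorted-order behavior and the helper names `field_index_map` and `unflatten` are assumptions for illustration, to be checked against your own `model.jit`.

```python
from typing import Dict, List, Sequence


def field_index_map(field_names: Sequence[str]) -> Dict[str, int]:
    """Map each field name to its position in a flattened output tuple.

    Assumption: fields are flattened in sorted key order, which is how
    detectron2's tracing export appears to order them. Verify on your export.
    """
    return {name: i for i, name in enumerate(sorted(field_names))}


def unflatten(out: Sequence[object], field_names: Sequence[str]) -> Dict[str, object]:
    """Rebuild a name -> value dict from a flattened output.

    Extra trailing items (e.g. an appended image size) are ignored.
    """
    index = field_index_map(field_names)
    return {name: out[i] for name, i in index.items()}


# Hypothetical field set for a Mask R-CNN model, listed in arbitrary order.
mask_rcnn_fields: List[str] = ["scores", "pred_masks", "pred_classes", "pred_boxes"]
print(field_index_map(mask_rcnn_fields))
# {'pred_boxes': 0, 'pred_classes': 1, 'pred_masks': 2, 'scores': 3}
```

Under the same assumption, a box-only model without `pred_masks` would put `scores` at index 2, while a Mask R-CNN model would put it at index 3.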
So you mean this would work?
```python
def forward(self, inputs: List[torch.Tensor]):
    x = inputs[0].unsqueeze(0) * 255
    scale = 320.0 / min(x.shape[-2], x.shape[-1])
    x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
    out = self.model(x)
    res: Dict[str, torch.Tensor] = {}
    res["boxes"] = out[0] / scale
    res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
    res["masks"] = out[2]
    res["scores"] = out[3]
    return inputs, [res]
```
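The `labels` line is a table lookup: `torch.index_select(self.coco_idx, 0, out[1])` picks `self.coco_idx[i]` for every predicted class id `i`. A plain-Python equivalent, with a small hypothetical lookup table standing in for `coco_idx` (`remap_labels` is an illustrative helper, not part of any library):

```python
from typing import List, Sequence


def remap_labels(lookup: Sequence[int], labels: Sequence[int]) -> List[int]:
    """Plain-Python equivalent of torch.index_select(lookup, 0, labels) for 1-D inputs."""
    return [lookup[i] for i in labels]


# Hypothetical lookup: contiguous training ids 0..3 -> category ids with gaps.
coco_idx = [1, 2, 3, 5]
print(remap_labels(coco_idx, [3, 0, 0, 2]))  # [5, 1, 1, 3]

# An identity table leaves the predicted class ids unchanged.
identity = list(range(4))
print(remap_labels(identity, [3, 0, 0, 2]))  # [3, 0, 0, 2]
```

If the model was trained on a custom dataset, a COCO-specific `coco_idx` table may not match its classes, so it may be worth checking whether the remap is needed at all.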
With the above changes, the `Wrapper` class doesn't throw any errors and the `d2go.pt` file gets created. But when I try to open this file in Netron, it throws an error.
```python
from typing import Tuple

import torch
from torch import Tensor


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        x = inputs.unsqueeze(0) * 255
        # Scale so the shorter image side becomes 320 px.
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
        out = self.model(x[0])
        # Map boxes back to the input resolution; the other outputs pass through unchanged.
        return out[0] / scale, out[1], out[2], out[3]
```
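Both wrappers resize the input so its shorter side is 320 px and then divide the boxes by the same `scale` to map them back to the original image. A plain-Python sketch of that arithmetic follows. It assumes that, with `recompute_scale_factor=True`, `interpolate` produces an output size of `floor(size * scale)`, and that boxes are `(x1, y1, x2, y2)` in resized-image pixels. The helper names `resize_plan` and `boxes_to_original` are hypothetical.

```python
import math
from typing import List, Tuple


def resize_plan(height: int, width: int, short_side: float = 320.0) -> Tuple[float, Tuple[int, int]]:
    """Return the scale that makes the shorter side `short_side`, plus the resized (H, W).

    Assumption: the resized size is floor(size * scale), as with
    recompute_scale_factor=True in torch.nn.functional.interpolate.
    """
    scale = short_side / min(height, width)
    return scale, (math.floor(height * scale), math.floor(width * scale))


def boxes_to_original(boxes: List[List[float]], scale: float) -> List[List[float]]:
    """Undo the resize for (x1, y1, x2, y2) boxes, as `out[0] / scale` does."""
    return [[c / scale for c in box] for box in boxes]


scale, size = resize_plan(640, 960)
print(scale, size)  # 0.5 (320, 480)
print(boxes_to_original([[10.0, 20.0, 30.0, 40.0]], scale))  # [[20.0, 40.0, 60.0, 80.0]]
```

Only the boxes are rescaled in these wrappers. Whether the masks need a similar adjustment depends on whether the export returns box-relative masks or full-image masks at the resized resolution, which I have not confirmed.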
I have trained a Mask R-CNN model, and I am looking for a `Wrapper` class to use with the code below.
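```python
import os

import torch

# predictor_path points to the directory containing the exported model.jit.
orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
wrapped_model = Wrapper(orig_model)
scripted_model = torch.jit.script(wrapped_model)
scripted_model.save("d2go.pt")
```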
I found this, but it seems to be meant for Fast R-CNN models:
```python
class Wrapper(torch.nn.Module):
```
Any idea?