facebookresearch / d2go

D2Go is a toolkit for efficient deep learning
Apache License 2.0

MaskRCNN Wrapper #624

Open nkhlS141 opened 9 months ago

nkhlS141 commented 9 months ago

I am looking for the Wrapper class used in the snippet below. I have trained a Mask R-CNN model.

import os
import torch

# predictor_path is the directory that contains the exported model.jit
orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
wrapped_model = Wrapper(orig_model)
scripted_model = torch.jit.script(wrapped_model)
scripted_model.save("d2go.pt")

I found the following, but it seems to be meant for Faster R-CNN models:

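import torch
from typing import Dict, List


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Maps the model's contiguous class indices to dataset category IDs
        coco_idx_list = [1]
        self.coco_idx = torch.tensor(coco_idx_list)

    def forward(self, inputs: List[torch.Tensor]):
        # Assumes a CHW float image in [0, 1]; add a batch dim and scale to [0, 255]
        x = inputs[0].unsqueeze(0) * 255
        # Resize so that the shorter side is 320 px
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
        # x[0] drops the batch dim again before calling the exported model
        out = self.model(x[0])
        res: Dict[str, torch.Tensor] = {}
        # Map boxes back to the original image resolution
        res["boxes"] = out[0] / scale
        res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
        res["scores"] = out[2]
        return inputs, [res]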

Any idea?
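For reference, the arithmetic this Wrapper relies on can be checked without PyTorch. The sketch below is a plain-Python illustration, not d2go code: resize_scale, boxes_to_original and map_labels are hypothetical helper names that mirror the scale computation, the out[0] / scale step and the torch.index_select label lookup.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

TARGET_SHORT_SIDE = 320.0  # same constant as in Wrapper.forward


def resize_scale(height: int, width: int) -> float:
    """Factor that makes the shorter image side 320 px (mirrors `scale`)."""
    return TARGET_SHORT_SIDE / min(height, width)


def boxes_to_original(boxes: List[Box], scale: float) -> List[Box]:
    """Boxes predicted on the resized image, divided by `scale` (mirrors out[0] / scale)."""
    return [(x1 / scale, y1 / scale, x2 / scale, y2 / scale) for x1, y1, x2, y2 in boxes]


def map_labels(pred_classes: List[int], idx_list: List[int]) -> List[int]:
    """Contiguous class index -> dataset category ID (mirrors torch.index_select)."""
    return [idx_list[c] for c in pred_classes]


# A 640 x 960 (H x W) image: the shorter side is 640, so scale = 320 / 640 = 0.5
scale = resize_scale(640, 960)
print(scale)  # 0.5
print(boxes_to_original([(32.0, 64.0, 160.0, 320.0)], scale))  # [(64.0, 128.0, 320.0, 640.0)]
print(map_labels([0, 0], [1]))  # [1, 1]
```

With a 640 x 960 image the model sees a 320 x 480 input, so every predicted box coordinate doubles when mapped back. With coco_idx_list = [1], every predicted class index 0 becomes category ID 1.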

wat3rBro commented 9 months ago

The cfg is from a Mask R-CNN model, so out should also contain the segmentation masks; adding them to res should turn Wrapper into a Mask R-CNN wrapper.

nkhlS141 commented 9 months ago

So you mean this would work?

def forward(self, inputs: List[torch.Tensor]):
    x = inputs[0].unsqueeze(0) * 255
    scale = 320.0 / min(x.shape[-2], x.shape[-1])
    x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
    out = self.model(x)
    res: Dict[str, torch.Tensor] = {}
    res["boxes"] = out[0] / scale
    res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
    res["masks"] = out[2]
    res["scores"] = out[3]
    return inputs, [res]

nkhlS141 commented 9 months ago

With the above changes to the Wrapper class, scripting runs without errors and the "d2go.pt" file gets created. But when I try to open this file in Netron, it throws an error.
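Netron reads TorchScript archives with its own JavaScript parser rather than with PyTorch, so an error there does not by itself show that d2go.pt is broken. A minimal sketch of a check in PyTorch itself: script the wrapper, save it, reload it with torch.jit.load and run it on a dummy image. Because model.jit is not available here, FakeMaskRCNN is a hypothetical stand-in that only mimics the (boxes, classes, masks, scores) tuple assumed in this thread, and MaskWrapper is an illustrative name; with the real model you would pass the result of torch.jit.load(".../model.jit") instead. It follows the original Wrapper in calling self.model(x[0]); whether the real model.jit expects the batched or unbatched tensor depends on how it was exported.

```python
import os
import tempfile
from typing import Dict, List, Tuple

import torch
from torch import Tensor


class FakeMaskRCNN(torch.nn.Module):
    """Hypothetical stand-in for model.jit, returning (boxes, classes, masks, scores)."""

    def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        h, w = image.shape[-2], image.shape[-1]
        boxes = torch.tensor([0.0, 0.0, 32.0, 32.0]).unsqueeze(0)  # one box, shape (N, 4)
        classes = torch.tensor([0])                                # contiguous class index
        masks = torch.zeros([1, 1, h, w])                          # one mask at input size
        scores = torch.tensor([0.9])
        return boxes, classes, masks, scores


class MaskWrapper(torch.nn.Module):
    """Illustrative Mask R-CNN variant of the Wrapper discussed in this issue."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        self.coco_idx = torch.tensor([1])

    def forward(self, inputs: List[Tensor]) -> Tuple[List[Tensor], List[Dict[str, Tensor]]]:
        x = inputs[0].unsqueeze(0) * 255
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(
            x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True
        )
        out = self.model(x[0])  # as in the original Wrapper; x[0] drops the batch dim
        res: Dict[str, Tensor] = {}
        res["boxes"] = out[0] / scale
        res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
        res["masks"] = out[2]
        res["scores"] = out[3]
        return inputs, [res]


# Script, save, and reload with PyTorch itself (independent of Netron)
scripted = torch.jit.script(MaskWrapper(FakeMaskRCNN()))
path = os.path.join(tempfile.mkdtemp(), "d2go.pt")
scripted.save(path)
reloaded = torch.jit.load(path)

_, results = reloaded([torch.rand(3, 640, 960)])
for name, value in results[0].items():
    print(name, tuple(value.shape))
```

This will not reproduce the Netron error, but if the reloaded module runs and returns the expected keys, the saved file is at least a valid TorchScript archive that PyTorch can execute.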

rochist commented 7 months ago

import torch
from torch import Tensor
from typing import Tuple


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        x = inputs.unsqueeze(0) * 255
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
        out = self.model(x[0])
        # Return a flat tuple of tensors instead of a dict; boxes are rescaled to the input resolution
        return out[0] / scale, out[1], out[2], out[3]