xingyizhou / CenterTrack

Simultaneous object detection and tracking using center points.
MIT License

[Question] Can CenterTrack model be compiled using TorchScript? #111

Open · ievbu opened this issue 4 years ago

ievbu commented 4 years ago

I want to load the model using TorchScript scripting/tracing. However, I tried the following without success:

Is it possible in general to save/load CenterTrack using TorchScript scripting/tracing? Maybe I just lack some knowledge of PyTorch.

xingyizhou commented 4 years ago

Sorry for the delayed reply. CenterTrack takes two frames and a prior heatmap as input, so you probably want to try something like the following (torch.jit.trace expects all example inputs packed into a single tuple):

    traced_model = torch.jit.trace(self.model, (torch.rand(1, 3, 512, 512), torch.rand(1, 3, 512, 512), torch.rand(1, 1, 512, 512)))
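
The three-input tracing suggestion can be checked end to end on a toy network. The sketch below uses a hypothetical TwoFrameNet, which is not CenterTrack's real architecture and only mimics its (image, previous image, prior heatmap) signature. It shows the example inputs passed to torch.jit.trace as one tuple, and a save/load round trip with torch.jit.save and torch.jit.load (through an in-memory buffer here instead of a file).

```python
import io

import torch
import torch.nn as nn


class TwoFrameNet(nn.Module):
    """Hypothetical stand-in with a CenterTrack-like input signature (not the real model)."""

    def __init__(self):
        super().__init__()
        self.img_conv = nn.Conv2d(3, 4, 3, padding=1)
        self.pre_img_conv = nn.Conv2d(3, 4, 3, padding=1)
        self.pre_hm_conv = nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x, pre_img=None, pre_hm=None):
        feats = self.img_conv(x)
        # Python-level branches are evaluated once at trace time, so the
        # traced graph always expects all three inputs.
        if pre_img is not None and pre_hm is not None:
            feats = feats + self.pre_img_conv(pre_img) + self.pre_hm_conv(pre_hm)
        return feats


model = TwoFrameNet().eval()
# All example inputs go into ONE tuple; the tracer calls forward(*example_inputs).
example_inputs = (
    torch.rand(1, 3, 512, 512),  # current frame
    torch.rand(1, 3, 512, 512),  # previous frame
    torch.rand(1, 1, 512, 512),  # prior heatmap
)
traced_model = torch.jit.trace(model, example_inputs)

# Save and reload the TorchScript module (a file path works the same way).
buffer = io.BytesIO()
torch.jit.save(traced_model, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
```

Because tracing records only the operations that actually ran, the traced module behaves like the branch taken with all three inputs present; code that calls the model with only the current frame would need a separate trace.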

nuzrub commented 2 years ago

Hey @ievbu and @xingyizhou, I got it to work with jit.trace. The main issue is that the model's forward function wraps its output in a list. You have to undo that by changing line 91 of BaseModel.py to return out[-1]. Then, in the process function of Detector.py, remove the [-1] right after the model call. Finally, pass strict=False to torch.jit.trace, which lets the tracer record mutable container outputs such as dicts and lists. Here is the trace call:

    import torch
    # Example inputs: current frame, previous frame, prior heatmap
    dummy_inputs = torch.rand((1, 3, 512, 512)), torch.rand((1, 3, 512, 512)), torch.rand((1, 1, 512, 512))
    self.model = torch.jit.trace(self.model, dummy_inputs, strict=False)
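
The same idea can be tried on a small model. This sketch assumes a hypothetical ListOutputNet whose forward returns a list holding one dict of head outputs, loosely modeled on the description above. Instead of editing BaseModel.py as the commenter did, it uses an alternative not mentioned in the thread: a hypothetical LastStackOutput wrapper that returns only the last list element, which is then traced with strict=False.

```python
import torch
import torch.nn as nn


class ListOutputNet(nn.Module):
    """Hypothetical model that wraps its head outputs in a list (not the real CenterTrack)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3 + 3 + 1, 8, 3, padding=1)
        self.heads = nn.ModuleDict({"hm": nn.Conv2d(8, 1, 1), "wh": nn.Conv2d(8, 2, 1)})

    def forward(self, x, pre_img, pre_hm):
        feats = torch.relu(self.backbone(torch.cat([x, pre_img, pre_hm], dim=1)))
        # One dict of head outputs per stack, collected in a list.
        return [{name: head(feats) for name, head in self.heads.items()}]


class LastStackOutput(nn.Module):
    """Hypothetical wrapper: returns the last stack's dict instead of editing the model source."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, pre_img, pre_hm):
        return self.model(x, pre_img, pre_hm)[-1]


model = ListOutputNet().eval()
dummy_inputs = (torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128), torch.rand(1, 1, 128, 128))
# strict=False lets the tracer record the dict output; this is only safe while
# the dict's keys do not depend on the input values.
traced = torch.jit.trace(LastStackOutput(model), dummy_inputs, strict=False)
outputs = traced(*dummy_inputs)
```

The wrapper leaves the original model untouched, so eager-mode code that indexes [-1] keeps working. The traced module, however, already returns the unwrapped dict, so the call site in Detector.py's process function still needs its [-1] removed when it uses the traced model.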

It is also fairly easy to update the project to the current PyTorch version, convert the model to half precision, and run torch.jit.optimize_for_inference on the traced module. I got a 10-15% speedup from this on an RTX 2070.
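
The half-precision plus optimize_for_inference step might look roughly like the following. It is a sketch under assumptions: TinyHeadNet and build_inference_module are hypothetical names, fp16 is applied only when a CUDA device is available (CPU support for float16 convolutions varies across PyTorch versions), and it requires a PyTorch version that provides torch.jit.optimize_for_inference. It does not measure or reproduce the reported speedup.

```python
import torch
import torch.nn as nn


class TinyHeadNet(nn.Module):
    """Hypothetical three-input network used only to exercise the optimization steps."""

    def __init__(self):
        super().__init__()
        self.img_conv = nn.Conv2d(3, 8, 3, padding=1)
        self.pre_img_conv = nn.Conv2d(3, 8, 3, padding=1)
        self.pre_hm_conv = nn.Conv2d(1, 8, 3, padding=1)
        self.hm_head = nn.Conv2d(8, 1, 1)

    def forward(self, x, pre_img, pre_hm):
        feats = torch.relu(self.img_conv(x) + self.pre_img_conv(pre_img) + self.pre_hm_conv(pre_hm))
        return self.hm_head(feats)


def build_inference_module(model, example_inputs):
    """Hypothetical helper: cast (fp16 on CUDA only), trace, then optimize for inference."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    # Freezing, which optimize_for_inference performs, requires eval mode.
    model = model.to(device=device, dtype=dtype).eval()
    inputs = tuple(t.to(device=device, dtype=dtype) for t in example_inputs)
    with torch.no_grad():
        traced = torch.jit.trace(model, inputs)
    # Freezes the module if it is not already frozen, then applies inference-oriented graph passes.
    return torch.jit.optimize_for_inference(traced), inputs
```

Usage: call `optimized, inputs = build_inference_module(model, example_inputs)` and then `optimized(*inputs)`; inputs at inference time must use the same device and dtype as the example inputs.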