ievbu opened this issue 4 years ago.
Sorry for the delayed reply. CenterTrack takes two frames and a prior heatmap as input, so you probably want to try something like
traced_model = torch.jit.trace(self.model, (torch.rand(1, 3, 512, 512), torch.rand(1, 3, 512, 512), torch.rand(1, 1, 512, 512)))
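Note that torch.jit.trace takes the traced callable followed by a single example_inputs argument, so several model inputs have to be packed into one tuple. Passed as separate positional arguments, the extra tensors would be bound to other parameters of trace instead of reaching the model. A minimal sketch, assuming a hypothetical ToyTracker module whose forward mirrors the (current frame, previous frame, prior heatmap) signature; it is not CenterTrack's network, and 64x64 inputs replace 512x512 only to keep it fast:

```python
import torch
import torch.nn as nn

class ToyTracker(nn.Module):
    """Hypothetical stand-in: takes the current frame, the previous frame,
    and a one-channel prior heatmap."""
    def __init__(self):
        super().__init__()
        # 3 + 3 + 1 = 7 input channels after concatenation
        self.conv = nn.Conv2d(7, 4, kernel_size=3, padding=1)

    def forward(self, x, pre_img, pre_hm):
        return self.conv(torch.cat([x, pre_img, pre_hm], dim=1))

model = ToyTracker().eval()

# All three inputs go into ONE tuple: the example_inputs argument of trace.
example_inputs = (
    torch.rand(1, 3, 64, 64),  # current frame
    torch.rand(1, 3, 64, 64),  # previous frame
    torch.rand(1, 1, 64, 64),  # prior heatmap
)
with torch.no_grad():
    traced = torch.jit.trace(model, example_inputs)
    out = traced(*example_inputs)

print(out.shape)  # torch.Size([1, 4, 64, 64])
```

The traced module is then called with the inputs unpacked, exactly like the original model.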
Hey @ievbu and @xingyizhou, I got it to work with jit.trace. The main issue is that the model's forward function wraps its output in a list. You have to undo that by changing line 91 of BaseModel.py to return out[-1]. Then, in the process function of Detector.py, remove the [-1] right after the model call. Finally, pass strict=False to the trace call. Here it is:
import torch

# current frame, previous frame, prior heatmap
dummy_inputs = (torch.rand((1, 3, 512, 512)), torch.rand((1, 3, 512, 512)), torch.rand((1, 1, 512, 512)))
self.model = torch.jit.trace(self.model, dummy_inputs, strict=False)
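If you would rather not edit BaseModel.py, the same unwrapping can be done with a thin wrapper module. This is a sketch under assumptions, not CenterTrack code: ToyMultiHead is a hypothetical stand-in whose forward returns a list with one dict of head outputs per stack (the head names hm and wh are illustrative), and LastStack is a hypothetical wrapper that returns only the last entry, which is what changing the return to out[-1] achieves. The traced output is then a dict, and the tracer only accepts dict outputs when strict=False is passed.

```python
import torch
import torch.nn as nn

class ToyMultiHead(nn.Module):
    """Hypothetical stand-in: returns a list (one entry per stack)
    of dicts mapping head names to tensors."""
    def __init__(self, num_stacks=2):
        super().__init__()
        self.backbone = nn.Conv2d(7, 8, 3, padding=1)
        self.hm = nn.Conv2d(8, 1, 1)
        self.wh = nn.Conv2d(8, 2, 1)
        self.num_stacks = num_stacks

    def forward(self, x, pre_img, pre_hm):
        feat = self.backbone(torch.cat([x, pre_img, pre_hm], dim=1))
        out = []
        for _ in range(self.num_stacks):
            out.append({"hm": self.hm(feat), "wh": self.wh(feat)})
        return out

class LastStack(nn.Module):
    """Hypothetical wrapper: returns only the last stack's dict,
    without modifying the wrapped model's source."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, pre_img, pre_hm):
        return self.model(x, pre_img, pre_hm)[-1]

model = LastStack(ToyMultiHead()).eval()
dummy_inputs = (torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64), torch.rand(1, 1, 64, 64))
with torch.no_grad():
    # strict=False lets the tracer record the dict output
    traced = torch.jit.trace(model, dummy_inputs, strict=False)
    out = traced(*dummy_inputs)

print(sorted(out.keys()))  # ['hm', 'wh']
```

With the wrapper, the caller still has to drop the [-1] after the model call, just as in the Detector.py change.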
It is also fairly easy to update the project to the current PyTorch version, convert the model to half precision, and run it through torch.jit.optimize_for_inference. This gave me a 10-15% speedup on an RTX 2070.
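A minimal sketch of the half-precision plus optimize_for_inference step, assuming a PyTorch version recent enough to provide torch.jit.optimize_for_inference. The small Conv-BN-ReLU model is a hypothetical stand-in for the traced network. fp16 is used only when a CUDA device is present, because half-precision convolution is not reliably supported on CPU. No speedup is measured here; the sketch only checks that the optimized module agrees with the eager one.

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for the traced network.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()  # freezing / optimize_for_inference require eval mode

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on the GPU in this sketch.
dtype = torch.float16 if device == "cuda" else torch.float32
model = model.to(device=device, dtype=dtype)
example = torch.rand(1, 3, 64, 64, device=device, dtype=dtype)

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    # Freezes the module if needed (inlining weights, e.g. folding BatchNorm
    # into the preceding conv) and applies inference-only passes.
    optimized = torch.jit.optimize_for_inference(traced)
    ref = model(example)
    out = optimized(example)

print(torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))
```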
I want to load the model using TorchScript scripting/tracing. However, I tried the following without success:
scripted_model = torch.jit.script(model)
and got an error. I then tried tracing:
traced_model = torch.jit.trace(self.model, torch.rand(1, 3, 512, 512))
and also got an error. Is it possible in general to save/load CenterTrack using TorchScript scripting/tracing? Maybe I just lack some knowledge of PyTorch.
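On the general save/load question: once a model traces successfully, torch.jit.save and torch.jit.load round-trip it. The sketch below assumes a hypothetical three-input ToyTracker rather than CenterTrack's real network, and uses an in-memory io.BytesIO buffer in place of a file path.

```python
import io
import torch
import torch.nn as nn

class ToyTracker(nn.Module):
    # Hypothetical three-input stand-in; not CenterTrack's real network.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(7, 2, 3, padding=1)

    def forward(self, x, pre_img, pre_hm):
        return self.conv(torch.cat([x, pre_img, pre_hm], dim=1))

inputs = (torch.rand(1, 3, 32, 32), torch.rand(1, 3, 32, 32), torch.rand(1, 1, 32, 32))
traced = torch.jit.trace(ToyTracker().eval(), inputs)

# Serialize to an in-memory buffer; a filename such as "model.pt" works the same way.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Loading does not need the Python class definition: graph and weights are in the archive.
loaded = torch.jit.load(buffer)
with torch.no_grad():
    same = torch.allclose(loaded(*inputs), traced(*inputs))
print(same)  # True
```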