google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0

Torchscript compatibility #83

Open SergeySandler opened 8 months ago

SergeySandler commented 8 months ago

While the PyTorch TAPIR model can easily be made compatible with TorchScript tracing by changing TAPIR.forward() in https://github.com/google-deepmind/tapnet/blob/main/torch/tapir_model.py#L196-L209 from

    out = dict(
        occlusion=torch.mean(
            torch.stack(trajectories['occlusion'][p::p]), dim=0
        ),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(
            torch.stack(trajectories['expected_dist'][p::p]), dim=0
        ),
        unrefined_occlusion=trajectories['occlusion'][:-1],
        unrefined_tracks=trajectories['tracks'][:-1],
        unrefined_expected_dist=trajectories['expected_dist'][:-1],
    )

    return out

to

    # Requires `from typing import NamedTuple` at the top of tapir_model.py.
    # Annotate with the type torch.Tensor, not the factory function torch.tensor.
    class Output(NamedTuple):
        occlusion: torch.Tensor
        tracks: torch.Tensor
        expected_dist: torch.Tensor

    out = Output(
        torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
        torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        torch.mean(torch.stack(trajectories['expected_dist'][p::p]), dim=0),
    )

    return out

(assuming it is OK to drop the unrefined_* entries from the output), so that

    import torch
    from tapnet.torch import tapir_model

    model = tapir_model.TAPIR(pyramid_level=1)
    model.load_state_dict(torch.load('bootstapir_checkpoint.pt'))
    model = model.to(torch.device('cpu'))
    model.eval()
    # Frames: (batch, num_frames, height, width, channels).
    dummy_input_frames = torch.randn(1, 32, 256, 256, 3, dtype=torch.float32, device=torch.device('cpu'))
    # Query points: (batch, num_points, 3).
    dummy_input_query_points = torch.randn(1, 20, 3, dtype=torch.float32, device=torch.device('cpu'))
    scriptModule = torch.jit.trace(model, (dummy_input_frames, dummy_input_query_points))
    torch.jit.save(scriptModule, 'bootstapir_checkpoint.ptc')

succeeds, it is much harder to make it compatible with TorchScript scripting.

    scriptModule = torch.jit.script(model)

fails with

    Module 'BlockV2' has no attribute 'proj_conv' :
      File "C:\tapnet\tapnet\torch\nets.py", line 278
        x = torch.relu(x)
        if self.use_projection:
          shortcut = self.proj_conv(x)
                     ~~~~~~~~~~~~~~ <--- HERE

How can the model be made compatible with TorchScript scripting?
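On the output side, scripting (unlike tracing) compiles the type annotations, so the Output class most likely needs to be defined outside forward(), for example at module level, with fields annotated as torch.Tensor. A minimal sketch on a toy module, assuming a recent PyTorch release; TrackOutput and ToyHead are hypothetical names that only mimic the three averaged TAPIR outputs:

```python
from typing import NamedTuple

import torch


# Hypothetical output type defined at module level so TorchScript can resolve it.
class TrackOutput(NamedTuple):
    occlusion: torch.Tensor
    tracks: torch.Tensor
    expected_dist: torch.Tensor


class ToyHead(torch.nn.Module):
    """Hypothetical stand-in for the last step of TAPIR.forward()."""

    def forward(self, x: torch.Tensor) -> TrackOutput:
        # Average over dim 0, mimicking torch.mean(torch.stack(...), dim=0).
        mean = x.mean(dim=0)
        return TrackOutput(mean, mean * 2.0, mean + 1.0)


scripted = torch.jit.script(ToyHead())
out = scripted(torch.ones(4, 3))
print(out[1])  # tensor([2., 2., 2.])
```

Indexing by position works whether the scripted module hands back a named tuple or a plain tuple.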

SergeySandler commented 8 months ago

It seems possible to get past the error above by adding an else clause to BlockV2.__init__ after

    if self.use_projection:
      self.proj_conv = nn.Conv2d(
          in_channels=channels_in,
          out_channels=channels_out,
          kernel_size=1,
          stride=stride,
          padding=0,
          bias=False,
      )

in https://github.com/google-deepmind/tapnet/blob/main/torch/nets.py#L225-L233, so that the code becomes

    if self.use_projection:
      self.proj_conv = nn.Conv2d(...)
    else:
      self.proj_conv = DummyModel()

where DummyModel is a do-nothing placeholder:

    class DummyModel:

        def __init__(self):
            pass

        def forward(self):
            return torch.tensor(0)

        def __call__(self, input):
            return self.forward()
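The original error arises because TorchScript compiles the body of "if self.use_projection:" even for blocks where the flag is False, so proj_conv has to exist on every instance. A common alternative to a hand-written placeholder class (a swapped-in technique, not necessarily what this thread or PR #85 ends up doing) is to register nn.Identity() in the else branch, so the attribute is always a real nn.Module. A minimal sketch; ToyBlock is hypothetical and much simpler than BlockV2:

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical residual block with an optional 1x1 projection shortcut."""

    def __init__(self, channels_in: int, channels_out: int, use_projection: bool):
        super().__init__()
        self.use_projection = use_projection
        if self.use_projection:
            self.proj_conv = nn.Conv2d(channels_in, channels_out, kernel_size=1, bias=False)
        else:
            # Placeholder so the attribute exists when TorchScript compiles
            # the (unused) projection branch in forward().
            self.proj_conv = nn.Identity()
        self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = torch.relu(x)
        if self.use_projection:
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


# Both variants script without the "has no attribute 'proj_conv'" error.
scripted_with = torch.jit.script(ToyBlock(4, 4, use_projection=True))
scripted_without = torch.jit.script(ToyBlock(4, 4, use_projection=False))
```

Unlike DummyModel, nn.Identity has a typed forward(input) that TorchScript can compile, so no special handling is needed for it.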

But then torch.jit.script(model) fails with

    Arguments for call are not valid.
    The following variants are available:

      aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
      Keyword argument axis unknown.

      aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
      Argument dim not provided.

      aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
      Argument dim not provided.

      aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
      Argument out not provided.

    The original call is:
      File "C:\tapnet\tapnet\torch\nets.py", line 61
        prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
        next_frame = torch.cat([x[1:], x[-1:]], dim=0)
        resid = torch.cat([x, prev_frame, next_frame], axis=1)
                ~~~~~~~~~ <--- HERE

This can be resolved by replacing resid = torch.cat([x, prev_frame, next_frame], axis=1) with resid = torch.cat([x, prev_frame, next_frame], dim=1). I'd like to know why axis does not raise an unexpected keyword argument error when the model runs normally. The next error that occurs is the following:

    Unknown type constructor Mapping:
      File "C:\tapnet\tapnet\torch\tapir_model.py", line 145
          get_query_feats: bool = False,
          refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
      ) -> Mapping[str, torch.Tensor]:
           ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
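Both errors come from TorchScript checking calls and annotations against a fixed set of schemas and types. In eager mode, PyTorch's Python argument parser accepts NumPy-style aliases such as axis for dim, which is why torch.cat(..., axis=1) runs normally; the TorchScript compiler instead matches keyword names against the aten::cat schemas listed in the error, which only know dim. Likewise, TorchScript supports Dict, List, Tuple and Optional from typing, but not the abstract Mapping. A minimal sketch of both fixes on a hypothetical helper, frame_features (not a tapnet function):

```python
from typing import Dict

import torch


@torch.jit.script
def frame_features(x: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Neighbouring frames along dim 0, padded by repeating the first/last frame.
    prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
    next_frame = torch.cat([x[1:], x[-1:]], dim=0)
    # dim=1, not axis=1: the TorchScript compiler only accepts the schema name.
    resid = torch.cat([x, prev_frame, next_frame], dim=1)
    # Dict[str, Tensor] instead of Mapping[str, Tensor] as the return annotation.
    return {'resid': resid}


x = torch.arange(6.0).reshape(3, 2)
print(frame_features(x)['resid'].shape)  # torch.Size([3, 6])
```

In eager code, torch.cat([x, x], axis=1) and torch.cat([x, x], dim=1) give the same result, so switching to dim is behavior-preserving.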

sgjheywa commented 8 months ago

Hi,

Thanks for raising the issue. Prior to release we were also able to trace the model using the same method you described, but in testing it showed very little performance improvement. Can I ask what the use case for scripting is here? Thanks

SergeySandler commented 8 months ago

@sgjheywa, scripting (torch.jit.script) makes it possible to save a model with dynamic dimensions, whereas tracing supports only static dimensions. Achieving JIT compatibility took many code changes; please review https://github.com/google-deepmind/tapnet/pull/85.
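The difference is clearest when Python control flow depends on a tensor dimension: torch.jit.trace records only the operations executed for the example input, so a loop over the number of frames is unrolled to the traced length, while torch.jit.script compiles the loop itself. A minimal sketch with a hypothetical sum_frames function (not from tapnet):

```python
import torch


def sum_frames(x: torch.Tensor) -> torch.Tensor:
    # Python loop whose trip count depends on the number of frames (dim 0).
    out = torch.zeros_like(x[0])
    for i in range(x.shape[0]):
        out = out + x[i]
    return out


traced = torch.jit.trace(sum_frames, torch.ones(3, 2))  # loop unrolled 3 times
scripted = torch.jit.script(sum_frames)                 # loop kept as a loop

five_frames = torch.ones(5, 2)
print(traced(five_frames))    # tensor([3., 3.]), only the first 3 frames
print(scripted(five_frames))  # tensor([5., 5.])
```

Many tensor ops in a traced graph do generalize to new shapes; the problem is specifically Python-level logic (loops, branches, int arithmetic) evaluated on sizes at trace time.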

sgjheywa commented 8 months ago

Sorry, I am familiar with scripting; I'm just trying to understand the use case here. Since the model is compatible with torch.compile, this seems unnecessary. Thanks

SergeySandler commented 8 months ago

@sgjheywa, the use case is LibTorch integration in C++. The model can be compiled with torch.compile, but that does not help, because the compiled model cannot be saved with torch.jit.save. Am I missing something? Thank you.
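For the LibTorch use case, the file written by torch.jit.save is what torch::jit::load reads on the C++ side. A minimal Python-side round-trip check before handing the file to C++, assuming TinyModel as a hypothetical stand-in for the scripted TAPIR model and an arbitrary temporary output path:

```python
import os
import tempfile

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a scripted TAPIR model."""

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # Mean over the frame dimension; valid for any number of frames.
        return frames.mean(dim=1)


scripted = torch.jit.script(TinyModel())
path = os.path.join(tempfile.mkdtemp(), 'tiny_model.pt')
torch.jit.save(scripted, path)

# Python counterpart of torch::jit::load(path) in LibTorch.
reloaded = torch.jit.load(path)
frames = torch.randn(1, 7, 16, 16, 3)
print(torch.allclose(reloaded(frames), scripted(frames)))  # True
```

The reloaded module also accepts a different number of frames, which is the dynamic-dimension property the C++ side would rely on.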

pubyLu commented 6 months ago

Hello! May I ask whether you have implemented model training for the Python version of TAPIR?