uzh-rpg / svit

Official implementation of "SViT: Revisiting Token Pruning for Object Detection and Instance Segmentation"
Apache License 2.0

Cannot test svit-adapter-t-0.5x-ftune.py because self.num_heads is not even #1

Open Livioni opened 1 year ago

Livioni commented 1 year ago

Great work; this is a milestone for bringing token pruning to dense prediction.

I found that svit-adapter-t-0.5x-ftune.py cannot be tested with an evaluation batch size greater than 1, because self.num_heads is not even.

In the InteractionBlockWithSelection class in adapter_modules.py, when x.shape[0] != 1 (i.e., the evaluation batch size is greater than 1), x is converted into a nested tensor and passed to blk (which wraps a TransformerEncoderLayer):

def forward(self, x, c, indexes, deform_inputs1, deform_inputs2, shape, blks, selective_modules, keep_ratio):
        n_skip = 3
        x = self.injector(query=x, reference_points=deform_inputs1[0],
                          feat=c, spatial_shapes=deform_inputs1[1],
                          level_start_index=deform_inputs1[2])
        layer_ratio_loss = 0.
        has_loss = 0
        for i in range(indexes[0], indexes[-1] + 1):
            if i < n_skip:
                x = blks[i](x)
            else:
                if self.training:
                    selector, diff_selector = selective_modules[i - n_skip](x)
                    x = diff_selector * blks[i](x, src_key_padding_mask=~selector) + \
                        (1 - diff_selector) * x
                    layer_ratio_loss += self._ratio_loss(diff_selector, keep_ratio[i - n_skip])
                    has_loss += 1
                else:
                    if x.shape[0] == 1:
                        selector, _ = selective_modules[i - n_skip](x)
                        real_indices = torch.argsort(selector.int(), dim=1, descending=True)\
                                        [:, :selector.sum(1)].unsqueeze(-1).expand(-1, -1, x.shape[-1])
                        selected_x = torch.gather(x, 1, real_indices)
                        selected_x = blks[i](selected_x)
                        x.scatter_(1, real_indices, selected_x)
                    else:
                        selector, diff_selector = selective_modules[i - n_skip](x)
                        l_aligned_x, l_aligned_mask = left_align_tokens2(x, selector)
                        nt_x = torch._nested_tensor_from_mask(l_aligned_x, l_aligned_mask, mask_check=False)
                        nt_x = blks[i](nt_x, src_key_padding_mask=None)
                        x.masked_scatter_(selector.unsqueeze(-1), torch.cat(nt_x.unbind(), 0))

        c = self.extractor(query=c, reference_points=deform_inputs2[0],
                           feat=x, spatial_shapes=deform_inputs2[1],
                           level_start_index=deform_inputs2[2], shape=shape)
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(query=c, reference_points=deform_inputs2[0],
                              feat=x, spatial_shapes=deform_inputs2[1],
                              level_start_index=deform_inputs2[2], shape=shape)
        return x, c, layer_ratio_loss, has_loss

However, MultiheadAttention in torch.nn does not appear to accept NestedTensor input when self.num_heads is odd (it is 3 for this model). How can I resolve this issue?

Exception has occurred: AssertionError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because self.num_heads is not even
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1212, in forward
    assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/TransformerEncoderLayer.py", line 250, in _sa_block
    x = self.self_attn(x, x, x,
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/TransformerEncoderLayer.py", line 239, in forward
    x = x + self.drop_path1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/selective_vit.py", line 91, in forward
    return self.TransformerEncoderLayer(x, src_key_padding_mask=src_key_padding_mask)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/adapter_modules.py", line 251, in forward
    nt_x = blks[i](nt_x, src_key_padding_mask=None)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/selective_vit_adapter.py", line 135, in forward
    x, c, layer_ratio_loss, has_loss = layer(x, c, indexes,
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/gumbel_two_stage.py", line 18, in extract_feat
    out = self.backbone(img, need_loss)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/two_stage.py", line 227, in predict
    x = self.extract_feat(batch_inputs)
  File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/base.py", line 94, in forward
    return self.predict(inputs, data_samples)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 346, in _run_forward
    results = self(**data, mode=mode)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
    return self._run_forward(data, mode='predict')  # type: ignore
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/loops.py", line 454, in run_iter
    outputs = self.runner.model.test_step(data_batch)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/loops.py", line 435, in run
    self.run_iter(idx, data_batch)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1823, in test
    metrics = self.test_loop.run()  # type: ignore
  File "/home/livion/Documents/github/source/ViT_Adapter/test.py", line 145, in main
    runner.test()
  File "/home/livion/Documents/github/source/ViT_Adapter/test.py", line 149, in <module>
    main()
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AssertionError: MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because self.num_heads is not even
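
The assertion fires only on the nested-tensor branch, so one possible fallback for batch size > 1 with an odd head count is to skip nested tensors and process the samples one at a time, mirroring the gather/scatter logic of the batch-size-1 branch in the code above. The sketch below is an illustration and not part of the repository: `run_block_per_sample` is a hypothetical helper, and the small `nn.TransformerEncoderLayer` (width 12, 3 heads) is a stand-in for the real block. It trades batching speed for compatibility.

```python
import torch
from torch import nn


def run_block_per_sample(blk, x, selector):
    """Apply blk to the selected tokens of each sample, one sample at a time.

    x:        (B, N, C) dense token tensor
    selector: (B, N) boolean mask, True for tokens that should be processed
    Hypothetical helper: it avoids NestedTensor entirely, so it does not
    depend on the MultiheadAttention fast path.
    """
    out = x.clone()
    for b in range(x.shape[0]):
        idx = selector[b].nonzero(as_tuple=True)[0]  # indices of kept tokens
        if idx.numel() == 0:
            continue  # nothing selected for this sample
        # Gather kept tokens, run the block on a batch of one, scatter back.
        out[b, idx] = blk(x[b, idx].unsqueeze(0)).squeeze(0)
    return out


# Stand-in block with an odd number of heads (assumed sizes, for illustration).
torch.manual_seed(0)
blk = nn.TransformerEncoderLayer(
    d_model=12, nhead=3, dim_feedforward=24, batch_first=True
).eval()

x = torch.randn(2, 5, 12)
selector = torch.tensor([[True, False, True, True, False],
                         [False, True, False, False, False]])

with torch.no_grad():
    out = run_block_per_sample(blk, x, selector)

print(out.shape)  # same shape as x
```

Tokens that are not selected pass through unchanged, which matches what the scatter step in the original code does.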

kaikai23 commented 11 months ago

Hi @Livioni

Thanks for your interest. Batched inference relies on nested_tensor in PyTorch's BetterTransformer fast path, which requires an even number of attention heads, so tiny models with 3 heads cannot run batched inference. One workaround is to use larger models, which usually have an even number of heads, as shown in Appendix B. If you need a batch size > 1 with a tiny model, you can change num_heads to 4 when finetuning.
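
As a rough illustration of the last suggestion, the sketch below picks an even head count that still divides the embedding width, which multi-head attention requires. The helper name `nearest_even_num_heads` is hypothetical, and the width of 192 is an assumption based on the usual DeiT-Tiny configuration; check the actual config before relying on it.

```python
def nearest_even_num_heads(embed_dim: int, num_heads: int) -> int:
    """Return the even head count closest to num_heads that divides embed_dim.

    Hypothetical helper for choosing a replacement head count; ties are
    resolved in favour of the larger value.
    """
    candidates = [h for h in range(2, embed_dim + 1, 2) if embed_dim % h == 0]
    if not candidates:
        raise ValueError(f"no even head count divides embed_dim={embed_dim}")
    return min(candidates, key=lambda h: (abs(h - num_heads), -h))


# Assumed tiny-model width of 192 with the original 3 heads.
suggested = nearest_even_num_heads(192, 3)
print(suggested)  # 4
```

With these assumed values the result is 4, which agrees with the suggestion above. Changing the head count changes how the projections are split across heads, which is why finetuning is needed afterwards.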