[Open] Livioni opened this issue 1 year ago
Hi @Livioni
Thanks for your interest. Since the batching relies on nested_tensor in PyTorch's BetterTransformer, which requires an even number of heads, tiny models with 3 heads cannot run batched inference. One workaround is to use bigger models, as they usually have an even number of heads, as shown in Appendix B. If you need batch size > 1 for tiny models, you can change num_heads to 4 when finetuning.
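For the finetuning route, here is a minimal sketch of the config change, assuming the repo follows the usual mmdet/ViT-Adapter convention of a Python config with a `backbone` dict; the exact keys and base-config path are assumptions and may differ in this codebase:

```python
# Hedged sketch, not verbatim from the repo: derive a finetune config that
# overrides the backbone's head count. The keys (model, backbone, num_heads)
# follow the common mmdet/ViT-Adapter layout and may differ here.
_base_ = ['./svit-adapter-t-0.5x-ftune.py']  # base-config path is an assumption

model = dict(
    backbone=dict(
        num_heads=4,  # even head count so the nested_tensor batched path works
        # The qkv projection shapes do not depend on num_heads, so pretrained
        # weights should still load, but the model must be finetuned again
        # with this setting, as noted above.
    )
)
```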
Great work, a milestone for bringing token pruning to dense prediction tasks.
I found that svit-adapter-t-0.5x-ftune.py cannot be tested because self.num_heads is not even.
In the InteractionBlockWithSelection class in adapter_modules.py, when x.shape[0] != 1 (i.e., the evaluation batch size is greater than 1), x is reshaped into a nested_tensor and passed into blk (a TransformerEncoderLayer).
However, it seems that torch.nn.MultiheadAttention does not support this computation when self.num_heads is set to 3. How can I resolve this issue?
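For reference, a minimal standalone sketch of the batched path described above (not the repo's code; the dimensions and token counts are made up for illustration):

```python
import torch
import torch.nn as nn

embed_dim = 192   # ViT-tiny width (illustrative)
num_heads = 4     # even head count works; num_heads=3 hits the error described above

blk = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, batch_first=True
).eval()

# After token pruning, each image keeps a different number of tokens,
# hence the variable-length sequences are packed into a nested tensor.
tokens_img0 = torch.randn(100, embed_dim)
tokens_img1 = torch.randn(57, embed_dim)
x = torch.nested.nested_tensor([tokens_img0, tokens_img1])

with torch.no_grad():
    out = blk(x)  # eval mode + no_grad dispatches to the BetterTransformer fast path
```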