huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[FEATURE] Support resizing rectangular pos_embed #2190

Open gau-nernst opened 2 months ago

gau-nernst commented 2 months ago

Is your feature request related to a problem? Please describe.

Currently, when changing a ViT's img size from a rectangular size, resample_abs_pos_embed() does not work correctly, since it does not know the original rectangular size and assumes a square.

https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L1096-L1103

https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/layers/pos_embed.py#L32-L34
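
For illustration, here is a minimal sketch of that square assumption (values follow the AudioMAE example below: a 1024x128 input with patch size 16 gives a 64x8 grid, i.e. 512 pos tokens):

    import math

    # When old_size is not provided, the old grid is inferred as a square
    # (this mirrors the linked lines in timm/layers/pos_embed.py).
    num_pos_tokens = 512                 # actually a 64x8 rectangular grid
    hw = int(math.sqrt(num_pos_tokens))
    old_size = hw, hw                    # (22, 22) -- but 22 * 22 = 484 != 512,
                                         # so the following reshape/interpolation breaks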

Describe the solution you'd like

It should work out of the box.

Describe alternatives you've considered

Manually resize the pos embed, e.g. as sketched below.
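
For reference, a hedged sketch of the manual workaround; the target grid here is hypothetical, while resample_abs_pos_embed() and its old_size argument are existing timm API:

    import timm
    import torch
    from timm.layers import resample_abs_pos_embed

    # Model from this issue, whose original pos embed grid is rectangular (64x8).
    model = timm.create_model(
        'hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m',
        pretrained=True,
    )

    # Pass the true old grid explicitly so it is not inferred as a square.
    new_pos_embed = resample_abs_pos_embed(
        model.pos_embed,
        new_size=(32, 8),   # hypothetical target grid
        old_size=(64, 8),   # true rectangular original grid
        num_prefix_tokens=0 if model.no_embed_class else model.num_prefix_tokens,
    )
    model.pos_embed = torch.nn.Parameter(new_pos_embed)

Note this only resamples the pos embed tensor itself; a full manual resize would also need the patch embed's stored img_size/grid_size updated.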

Additional context

Apparently dynamic img size will also not work when the original img size is rectangular.

https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L603-L609

This is a rare problem since most image ViTs use square inputs. The particular model I'm using is my previously ported AudioMAE (https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m), which uses rectangular mel-spectrogram inputs (1024x128, so a 64x8 grid with patch size 16).

I understand it is not so straightforward to support this, since once the model is created (with the updated image size), the original image size is lost. Some hacks could work around this, though none are particularly clean:

  1. Propagate the original image size to the _load_weights() function.
  2. Create a model with the original image size and load the weights as usual, then add a new method like .set_img_size() which updates the internal img_size attribute and resamples the pos embed (a rough sketch follows this list).
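
As a rough sketch of option 2 (everything here is hypothetical, not existing timm API; patch_embed.set_input_size() is an assumed helper that updates img_size and grid_size):

    import torch
    from timm.layers import resample_abs_pos_embed

    # Hypothetical method on VisionTransformer -- illustrative only.
    def set_img_size(self, img_size):
        old_grid_size = self.patch_embed.grid_size  # capture original grid before updating
        self.patch_embed.set_input_size(img_size)   # assumed helper
        self.pos_embed = torch.nn.Parameter(resample_abs_pos_embed(
            self.pos_embed,
            new_size=self.patch_embed.grid_size,
            old_size=old_grid_size,                 # true (possibly rectangular) old grid
            num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
        ))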

Perhaps an easier solution is to fix dynamic img size to pass the original img size (which I tested locally, and it works):

        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                self.patch_embed.grid_size,  # added: pass the true (possibly rectangular) old size
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )

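With that change, a model with a rectangular pos embed created with dynamic_img_size=True should accept other input sizes directly. A hedged usage sketch (single-channel spectrogram input assumed):

    import timm
    import torch

    # Assumes the fix above is applied; the model id is from this issue.
    model = timm.create_model(
        'hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m',
        pretrained=True,
        dynamic_img_size=True,
    )
    out = model(torch.randn(1, 1, 512, 128))  # non-default (rectangular) input size
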
rwightman commented 2 months ago

@gau-nernst I have thought about this, and yeah, it's not done because doing it at weight load time (where it's handled now) is more complexity than I'd like, especially for the benefit/demand.

That said, I need a set_img_size() like fn for another project, so it will be supported through that mechanism. I was planning to do something similar to the swin_v2_cr implementation (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer_v2_cr.py#L324-L337) but tweak naming a bit...

I believe dynamic resize can be made to work; it needs the 'old size' arg to be non-default...

rwightman commented 2 weeks ago

@gau-nernst on PR #2225 there is a first pass at implementing a set_input_size fn... currently it should mostly work for models from vision_transformer.py, vision_transformer_hybrid.py, and swin_transformer.py (v1)
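
A hedged usage sketch of that API; the exact signature on PR #2225 may differ:

    import timm

    # Create at the original (rectangular) size, then resize afterwards.
    # set_input_size() is the new method referenced above; kwargs assumed.
    model = timm.create_model(
        'hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m',
        pretrained=True,
    )
    model.set_input_size(img_size=(512, 128))  # hypothetical new input size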