OML-Team / open-metric-learning

Metric learning and retrieval pipelines, models and zoo.
https://open-metric-learning.readthedocs.io/en/latest/index.html
Apache License 2.0

Support positional encoding interpolation in Zoo models #601

Open korotaS opened 3 months ago

korotaS commented 3 months ago

Hi! I am using this example to train a ConcatSiamese model. With the extractor from the example (vits16_dino) it runs fine, but with a different model, for example vitb32_unicom, I get an error. Here is a minimal reproduction:

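```python
import torch

# Import paths assumed from OML's public API at the time of this issue
from oml.models import ConcatSiamese, ViTUnicomExtractor
from oml.registry.transforms import get_transforms_for_pretrained

device = 'cpu'
extractor = ViTUnicomExtractor.from_pretrained("vitb32_unicom").to(device)
transforms, _ = get_transforms_for_pretrained("vitb32_unicom")
pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100], device=device)

# Two batches of 224x224 images; this call raises the RuntimeError below
out = pairwise_model(x1=torch.rand(2, 3, 224, 224), x2=torch.rand(2, 3, 224, 224))
```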

And here is the last traceback item:

```
File /home/korotas/projects/open-metric-learning/oml/models/vit_unicom/external/vision_transformer.py:181, in VisionTransformer.forward_features(self, x)
    179 B = x.shape[0]
    180 x = self.patch_embed(x)
--> 181 x = x + self.pos_embed
    182 for func in self.blocks:
    183     x = func(x)

RuntimeError: The size of tensor a (98) must match the size of tensor b (49) at non-singleton dimension 1
```

I think the problem is that vits16_dino interpolates the positional embedding before adding it, so the embedding is resized to match the number of patches in the input: https://github.com/OML-Team/open-metric-learning/blob/05842c84c980fa2c34510e3de176cd5c5dae9205/oml/models/vit_dino/external/vision_transformer.py#L260-L280

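But vitb32_unicom has no such interpolation. With 32x32 patches, a 224x224 image yields 7 x 7 = 49 patches, which is exactly what the positional encoding expects, while the input that ConcatSiamese builds from the two images is twice as large and yields 98 patches. So we would need to manually replace the layers that depend on the number of patches (model.pos_embed and model.feature[0]).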
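The two numbers in the error follow from simple patch arithmetic. A quick sketch, assuming square 32x32 patches for vitb32_unicom and assuming the concatenated pair is a single 448x224 input (the concatenation axis is an assumption; either axis gives the same count). `num_patches` is a hypothetical helper:

```python
def num_patches(height: int, width: int, patch_size: int) -> int:
    """Patches produced by a non-overlapping, stride-`patch_size` patch embedding."""
    return (height // patch_size) * (width // patch_size)


single = num_patches(224, 224, 32)  # what pos_embed was built for: 7 * 7
pair = num_patches(448, 224, 32)    # two 224x224 images joined into one input
print(single, pair)  # 49 98
```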

I think this should be mentioned in the docs, or perhaps handled in ViTUnicomExtractor or ConcatSiamese with some kind of warning. The same error also occurs with all other ViTUnicomExtractor and ViTCLIPExtractor models.

AlekseySh commented 3 months ago

@korotaS thank you for the report!

You are absolutely right. Note that we haven't changed the original implementations: everything under the external folder is copied as-is. The problem is indeed the missing interpolation.

As for a solution on the library side, I think a compromise works best: let's add the interpolation, but also raise a warning whenever the input image size requires it. In other words, we warn the user when the code behaves slightly differently from the original implementation.
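The interpolation proposed here could follow what DINO does. The sketch below is not the code from either repository: it assumes a positional embedding of shape (1, N, D) that covers only patch tokens on a square grid (no class token, consistent with the unicom traceback, where `pos_embed` is added directly to the patch embeddings) and uses bicubic resizing. `interpolate_pos_embed` is a hypothetical helper name. DINO's own version keeps the class-token embedding separate and passes a scale factor rather than a target size.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
    """Resize a (1, N, D) patch-only positional embedding to a grid_h x grid_w grid.

    Hypothetical helper: assumes the N original positions form a square grid.
    """
    _, n, dim = pos_embed.shape
    side = math.isqrt(n)
    assert side * side == n, "expected a square grid of positions"
    if (grid_h, grid_w) == (side, side):
        return pos_embed  # sizes already match, nothing to interpolate
    # (1, N, D) -> (1, D, side, side) so that F.interpolate treats D as channels
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(grid_h, grid_w), mode="bicubic", align_corners=False)
    # (1, D, grid_h, grid_w) -> (1, grid_h * grid_w, D)
    return grid.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)


pos_embed = torch.randn(1, 49, 768)                # 7 x 7 grid, as for 224 / 32
resized = interpolate_pos_embed(pos_embed, 14, 7)  # grid of a 448x224 input
print(resized.shape)  # torch.Size([1, 98, 768])
```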

There is one more problem to solve: the check will live in forward(), and we want to avoid thousands of warnings, so the user should be warned only once. We can use a decorator for this. Here is a draft:

```python
import warnings

def warn_if_even(func):
    def wrapper(n):
        if n % 2 == 0:
            warnings.warn(f"Input value {n} is an even number")
        return func(n)
    return wrapper

@warn_if_even
def my_function(n):
    # Your function logic here
    print(f"Function called with argument: {n}")

# Example of calling the function with different values
for i in range(1, 6):
    my_function(i)
```
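As written, the draft warns on every even argument (here for 2 and 4); because the message text changes with n, Python's default warning filter does not deduplicate them. To warn only once, as intended, the decorator can keep a flag. A minimal sketch; `warn_once_if`, the `forward` stand-in, and the patch-count check are hypothetical illustrations, not OML code:

```python
import functools
import warnings
from typing import Any, Callable


def warn_once_if(condition: Callable[..., bool], message: str) -> Callable:
    """Decorator: emit `message` the first time `condition(*args, **kwargs)` is true."""

    def decorator(func: Callable) -> Callable:
        warned = False  # shared by all calls of the decorated function

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal warned
            if not warned and condition(*args, **kwargs):
                warnings.warn(message)
                warned = True
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Hypothetical stand-in for forward(): compares the number of input patches
# with the 49 positions the embedding was trained with.
@warn_once_if(
    lambda n_patches: n_patches != 49,
    "Number of patches differs from pos_embed; interpolating positional encoding.",
)
def forward(n_patches: int) -> int:
    return n_patches


for n in (49, 98, 98, 121):
    forward(n)  # warns only on the first mismatch (98)
```

Because the flag lives in the decorator's closure, decorating a method such as forward() would share it across all instances, and `condition` would receive `self` as its first argument.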

Do you want to make this contribution to the library, @korotaS? I would greatly appreciate it :)

korotaS commented 3 months ago

Do you think interpolating is better than replacing the size-dependent layers? I know it is generally not a good idea to replace trained layers with untrained ones, but they might learn quickly (especially the positional encoding).
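The alternative raised here, swapping the size-dependent parameter for a freshly initialized one, might look like this on a toy module. This is a hypothetical sketch: `ToyViT` and `reset_pos_embed` are illustrative names, not OML or unicom classes, and the truncated-normal initialization with std 0.02 is an assumption borrowed from common ViT practice. In the real unicom model, `model.feature[0]` would also need resizing, which the toy omits.

```python
import torch
from torch import nn


class ToyViT(nn.Module):
    """Minimal stand-in holding only a patch-only positional embedding."""

    def __init__(self, num_patches: int = 49, dim: int = 64) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens + self.pos_embed


def reset_pos_embed(model: ToyViT, num_patches: int) -> None:
    """Replace pos_embed with a new, untrained parameter sized for `num_patches`."""
    dim = model.pos_embed.shape[-1]
    new = torch.empty(1, num_patches, dim)
    nn.init.trunc_normal_(new, std=0.02)
    model.pos_embed = nn.Parameter(new)  # registered as a parameter by nn.Module


model = ToyViT()
reset_pos_embed(model, 98)
out = model(torch.rand(2, 98, 64))
print(out.shape)  # torch.Size([2, 98, 64])
```

Any optimizer created before the swap still references the old tensor, so it should be built after this step.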

As for this:

Do you want to make this contribution to the library, @korotaS?

I think I can; maybe not today or tomorrow, but I will let you know as soon as possible.

AlekseySh commented 3 months ago

@korotaS

DINO was also trained on a fixed set of image sizes, but it allows interpolation at inference time.

There is a simple way to check whether everything works: run validation on any of our benchmarks with im_size=360 and compare the results with those reported in the zoo table.
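For a sense of how far im_size=360 is from the training resolution, the patch counts can be worked out directly. This sketch assumes square inputs, the patch sizes encoded in the model names (16 for vits16_dino, 32 for vitb32_unicom), and that both positional embeddings were built for 224x224 inputs (the 49 positions in the traceback confirm this only for vitb32_unicom). Since 360 is divisible by neither patch size, a stride-p patch embedding drops the remainder. `grid_side` is a hypothetical helper:

```python
def grid_side(im_size: int, patch_size: int) -> int:
    """Patches per side for a square image; a stride-`patch_size` embedding floors the result."""
    return im_size // patch_size


for name, patch in (("vits16_dino", 16), ("vitb32_unicom", 32)):
    trained = grid_side(224, patch) ** 2  # positions in pos_embed (assumed 224x224 training size)
    at_360 = grid_side(360, patch) ** 2   # patches produced at im_size=360
    print(f"{name}: {trained} positions, {at_360} patches at 360")
# vits16_dino: 196 positions, 484 patches at 360
# vitb32_unicom: 49 positions, 121 patches at 360
```

Both models therefore need interpolation at 360, which is what makes this size a useful test.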