korotaS opened this issue 3 months ago
@korotaS thank you for the report!
You are absolutely right. Note that we haven't changed the original implementations (everything under the `external` folder is copy-pasted). The problem is indeed in the interpolation.
As for the solution on the library side, I think there may be a compromise: let's add the interpolation, but also raise a warning when the image size implies that the interpolation has to be applied. In other words, we warn the user when the code behaves slightly differently from the original implementation.
There is also a problem we need to solve: the check will be placed in `forward()`, but we want to avoid thousands of warnings, so we need to warn the user only once. We can use a decorator for this purpose. Here is a draft:
```python
import warnings


def warn_if_even(func):
    # Warn only once, on the first even input, not on every call.
    def wrapper(n):
        if n % 2 == 0 and not wrapper.warned:
            warnings.warn(f"Input value {n} is an even number")
            wrapper.warned = True
        return func(n)

    wrapper.warned = False
    return wrapper


@warn_if_even
def my_function(n):
    # Your function logic here
    print(f"Function called with argument: {n}")


# Example of calling the function with different values
for i in range(1, 6):
    my_function(i)
```
Do you want to make this contribution to the library, @korotaS? I would greatly appreciate it :)
Do you think that interpolating is better than replacing some layers? I know it is not a good idea to replace trained layers with non-trained ones, but they might learn quickly (especially the positional encoding).
As for this:

> Do you want to make this contribution to the library, @korotaS?
I think I can, maybe not today and not tomorrow, but I will let you know ASAP.
@korotaS
In DINO they also trained on some fixed set of image sizes, but at inference time they allow interpolation.
We have a simple way to check that everything is fine: just run validation on any of our benchmarks with `im_size=360` and compare the results with the ones provided in the zoo table.
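The DINO-style trick mentioned above can be sketched roughly like this. This is a minimal sketch, assuming the standard `(1, 1 + N, dim)` `pos_embed` layout with a single CLS prefix token on a square patch grid; the function name and signature are my own illustration, not the library's API:

```python
import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, new_grid: int,
                          num_prefix_tokens: int = 1) -> torch.Tensor:
    """Resize patch position embeddings to a new (square) patch grid.

    pos_embed: tensor of shape (1, num_prefix_tokens + old_grid**2, dim).
    Returns a tensor of shape (1, num_prefix_tokens + new_grid**2, dim).
    """
    prefix = pos_embed[:, :num_prefix_tokens]       # e.g. the CLS token embedding
    patch_pos = pos_embed[:, num_prefix_tokens:]    # per-patch embeddings
    old_grid = int(patch_pos.shape[1] ** 0.5)
    dim = patch_pos.shape[-1]

    # (1, N, dim) -> (1, dim, old_grid, old_grid) so F.interpolate sees a 2D map
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    # back to (1, new_grid**2, dim)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([prefix, patch_pos], dim=1)
```

For example, embeddings trained for a 14x14 patch grid (224px image, patch 16) could be resized to 16x16 for a 256px input; the prefix token is passed through untouched.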
Hi! I am using this example to train a `ConcatSiamese` model. If I use the extractor from the example (`vits16_dino`), it runs OK, but if I use a different model, for example `vitb32_unicom`, I get an error. Here is an example for reproducing, and here is the last traceback item:

I think the problem is that in `vits16_dino` there is an interpolation before the positional embedding, so the tensor is downsampled into the preferred shape: https://github.com/OML-Team/open-metric-learning/blob/05842c84c980fa2c34510e3de176cd5c5dae9205/oml/models/vit_dino/external/vision_transformer.py#L260-L280

But in `vitb32_unicom` there is no such interpolation, so the input tensor is 2 times bigger than the positional encoding expects, and we need to manually replace some layers that depend on the number of patches (`model.pos_embed` and `model.feature[0]`).

I think this information needs to be added to the docs, or perhaps handled in `ViTUnicomExtractor` or in `ConcatSiamese` with some kind of warning. Also, this error reproduces with all other `ViTUnicomExtractor` and `ViTCLIPExtractor` models.
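The size check behind such a warning could look something like the sketch below. It is only an illustration of the idea, not library code: the function name, the `num_prefix_tokens` argument, and the example numbers are my own assumptions (one CLS token, square images):

```python
import warnings


def check_pos_embed_size(im_size: int, patch_size: int, pos_embed_len: int,
                         num_prefix_tokens: int = 1) -> bool:
    """Return True if the input size matches the positional embedding length."""
    # Number of patch tokens for a square image, plus prefix tokens (e.g. CLS).
    expected = (im_size // patch_size) ** 2 + num_prefix_tokens
    if expected != pos_embed_len:
        warnings.warn(
            f"Image size {im_size} produces {expected} tokens, but pos_embed "
            f"has {pos_embed_len}; the positional encoding would need to be "
            f"interpolated or replaced."
        )
        return False
    return True


# Hypothetical vitb32_unicom-like setup: pos_embed trained for 224px inputs.
check_pos_embed_size(224, 32, (224 // 32) ** 2 + 1)  # True, sizes match
check_pos_embed_size(448, 32, (224 // 32) ** 2 + 1)  # False, warns
```

Such a check could run in `forward()` and be combined with a warn-once mechanism so it does not fire on every batch.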