facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

Optimizing Image Preprocessing for the Dinov2 Model: Balancing Accuracy and Information Retention #86

Closed aaiguy closed 1 year ago

aaiguy commented 1 year ago

I'm using the DINOv2 model to extract features from images. For preprocessing the images before passing them to the model, I'm following the same procedure used during training of the DINOv2 pretrained model on the ImageNet dataset, as shown in the code below:

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Load the image
filename = r"C:\Users\NLRH1411\Downloads\istockphoto-1391649784-170667a.jpg"
img = Image.open(filename)

# Define the transformation pipeline
transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Flatten RGBA images onto a white background, then apply the transformation
if img.mode == 'RGBA':
    img.load()
    background = Image.new("RGB", img.size, (255, 255, 255))
    background.paste(img, mask=img.split()[3])
    img = background
img_tensor = transform(img)[:3].unsqueeze(0)

# Reverse the normalization so the preprocessed tensor can be displayed
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
output_img = img_tensor.squeeze().numpy().transpose((1, 2, 0))
output_img = std * output_img + mean
output_img = np.clip(output_img, 0, 1)

# Create the figure and subplots
fig, axs = plt.subplots(1, 2)

# Plot the input image in the first subplot
axs[0].imshow(img)
axs[0].set_title('Input Image')

# Plot the output image in the second subplot
axs[1].imshow(output_img)
axs[1].set_title('After Image Preprocessing')

# Display the figure
plt.show()

[Screenshot: input image vs. image after preprocessing]

As you can see, some portion of the preprocessed image gets cropped, which affects the output result since that image information is lost. Is there a way to avoid this during preprocessing, without losing any image information and without affecting accuracy?

aaiguy commented 1 year ago

Any thoughts on this, @woctezuma?

woctezuma commented 1 year ago

This is due to the center-cropping operation:

https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/data/transforms.py#L87-L88

I don't know how it would affect accuracy to resize to 224 instead of resizing to 256 and then center-cropping to 224 resolution.
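Something like this (untested sketch) would drop the crop and resize straight to 224 × 224, at the cost of squashing the aspect ratio of non-square images:

import torchvision.transforms as T

# Untested sketch: resize directly to 224x224 (no center crop), keeping the
# whole image but distorting the aspect ratio of non-square inputs.
no_crop_transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])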

However, depending on the task, it could make sense. If you look at DINO (the first version):

See:

aaiguy commented 1 year ago

Yes, it's affecting results, especially when the objects in the image are big or inverted. Is it possible to resize an image to 224 without losing any portion of the objects or any quality? Just resizing to 224 will shrink the objects.

ccharest93 commented 1 year ago

I haven't experimented with this yet, but the positional-embedding interpolation should allow for non-square images, so the center crop is not necessary. Given an initial image of H × W and a patch size of 14, you could crop to (H - H % 14) × (W - W % 14) and feed that as input to the network (you could also resize before doing that crop if you want fewer patches).
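Roughly like this (untested sketch; it assumes the torch hub ViT-S/14 checkpoint and the forward_features output keys from this repo, and the image path is just a placeholder):

import torch
import torchvision.transforms as T
from PIL import Image

# Untested sketch: trim H and W down to the nearest multiple of the patch size
# and feed the rectangular image as-is, instead of center-cropping to a square.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval()
patch_size = 14

img = Image.open('example.jpg').convert('RGB')  # placeholder path
W, H = img.size
newH, newW = H - H % patch_size, W - W % patch_size

transform = T.Compose([
    T.CenterCrop((newH, newW)),  # removes at most patch_size - 1 pixels per dimension
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

with torch.no_grad():
    feats = model.forward_features(transform(img).unsqueeze(0))
print(feats['x_norm_patchtokens'].shape)  # (1, (newH // 14) * (newW // 14), dim)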

TimDarcet commented 1 year ago

As you can see, some portion of the preprocessed image gets cropped, which affects the output result since that image information is lost. Is there a way to avoid this during preprocessing, without losing any image information and without affecting accuracy?

You are correct, the preprocessing does do that. The effect of this can be positive or negative depending on many factors, we did not really tweak this honestly. My guess would be that it's good only on object-centric datasets such as Imagenet.

You can make the crop less aggressive by resizing to 224, as @woctezuma suggested.

If you want to keep all of the image, you can either:

It's hard to say which would be best. You'd have to test to see what works best for you. In any case, I don't think it should affect performance that much.
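For example, one way to keep the whole image is to pad it to a square before the usual resize (rough sketch; the PadToSquare helper here is just for illustration):

import torchvision.transforms as T
import torchvision.transforms.functional as TF

class PadToSquare:
    # Rough sketch: pad the shorter side with a constant color so the image becomes square.
    def __call__(self, img):
        w, h = img.size
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        return TF.pad(img, [left, top, side - w - left, side - h - top], fill=0)

pad_transform = T.Compose([
    PadToSquare(),
    T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])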

patricklabatut commented 1 year ago

Closing as answered, please re-open if you need to discuss this more.

smandava98 commented 1 year ago

@TimDarcet I have similar issues as OP. All of my images are rectangular, and I found, after visualizing the features via PCA, that rectangular images don't work well; even resizing the images explicitly does not work well either. I tried padding to a square image and it also does pretty poorly (you can barely see any objects under these conditions). What consistently works well seems to be what is used with ImageNet (resize with bicubic interpolation, center crop, and ImageNet normalization).

However, my issue is that I need all of the image (even the details at the edges), because my images are medical images, so there is important information all around the image. Do you have a recommended preprocessing method that handles rectangular images while still preserving edge details?

smandava98 commented 1 year ago

Let me know if you have any recommendations, cc @woctezuma.

TimDarcet commented 1 year ago

Hi, could you give a bit more detail? A few questions:

A few insights:

smandava98 commented 1 year ago

All my images are of size (640, 368), mainly because they are video frames taken from 360p videos of cells/protein structures/etc.

I use this transform:

H = 640
W = 368
patch_size = 14
# Round H and W down to the nearest multiple of the patch size
newH = H - H % patch_size
newW = W - W % patch_size
print(f"New widths and heights are {newW,newH}")

transform = T.Compose([
    T.Resize((newH, newW), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

And I get a visualization like this with PCA (n=3):

[Screenshot: PCA (n=3) visualization for the rectangular input]

When I use this transform (center crop, i.e. a square image):

H = 640
W = 368
patch_size = 14
# Round H and W down to the nearest multiple of the patch size
newH = H - H % patch_size
newW = W - W % patch_size
print(f"New widths and heights are {newW,newH}")

transform = T.Compose([
    T.Resize((newH, newW), interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(newW),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

I get something like this:

[Screenshot: PCA (n=3) visualization for the center-cropped input]

With the center-crop transform, I can clearly see the objects. In fact, I'm pretty impressed because it's able to get features for basically all the objects. But with the original rectangular image, it is a bit blurry and there is some weird color distortion going on (the bottom half is very light).

When I visualize an attention head I can see that DINOv2 is paying attention to various features (I can see objects clearly in some images) with those rectangular images. Here is an example of one attention map:

[Screenshot: attention map for a rectangular input]

Even though rectangular images are OOD, the attention maps suggest the model's internal representations are still valuable.

This is the root cause of my confusion: the PCA visualization does not give good results, but the attention results seem decent, so I am not sure if the model is giving useful representations for my data overall.

Center cropping does not make sense for my use case because valuable info is contained everywhere. My end purpose is that I am hoping to do cell-based segmentation.

And yeah, I tested padding and it does not do well at all. I've narrowed it down to using ImageNet norms (which actually performs decently well for medical images) and just resizing with no cropping.

TimDarcet commented 1 year ago

But with the original rectangular image, it is a bit blurry and there is some weird color distortion going on (the bottom half is very light).

The PCA you have seems okay to me; it's just that you are only visualizing the first component here. Showing (separately) a few other components (e.g. the first 10) should reveal the structure you are looking for.
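Rough sketch of what I mean (it assumes sklearn, a feats dict from forward_features as in the earlier snippets, and the 45 × 26 patch grid of a 630 × 364 input):

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Rough sketch: project patch tokens onto the first 10 principal components and
# show each component as its own heatmap, one subplot per component.
h_patches, w_patches = 45, 26  # 630 // 14, 364 // 14
tokens = feats['x_norm_patchtokens'][0].cpu().numpy()  # (num_patches, dim)

pca = PCA(n_components=10)
components = pca.fit_transform(tokens)  # (num_patches, 10)

fig, axs = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axs.flat):
    ax.imshow(components[:, i].reshape(h_patches, w_patches))
    ax.set_title(f'PC {i + 1}')
    ax.axis('off')
plt.show()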

I am not sure if the model is giving useful representations for my data overall.

As you said, the attention map looks okay, so the model is doing something. I think the representations should have some value.