allenai / mmc4

MultimodalC4 is a multimodal extension of C4 that interleaves millions of images with text.
MIT License

CLIP ViT-L/14 weights #12

Closed: josep-alana closed this issue 1 year ago

josep-alana commented 1 year ago

Hi, which weights did you use to compute the image features? So far I have tried the HF ones and the OpenAI ones, but the feature vectors I get for the images are significantly different from the precomputed ones you shared. I can share some minimal code of what I tried if it helps. Thanks!

jmhessel commented 1 year ago

Hi! Quick response: thanks for the heads-up! We extracted features on TPU using a JAX port of CLIP, which has previously been verified to produce the same outputs as the OpenAI version. Precision settings like fp16/bf16/fp32 may cause numerical differences. Similarly, preprocessing may work slightly differently on TPU, which I am checking on now. When you say they are significantly different, I assume this is beyond numerical precision?
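One way to answer "is this beyond numerical precision?" is to compare whole vectors instead of eyeballing the first few values. A minimal sketch, assuming both inputs are 1-D NumPy-compatible arrays (for example a row of the released pickle and the output of `encode_image`); `compare_features` is a hypothetical helper, not part of mmc4 or CLIP, and there is no official threshold. Precision noise should leave the cosine very close to 1 and the relative error small, while a genuinely different input usually moves both noticeably.

```python
import numpy as np


def compare_features(a, b):
    """Summarize how far apart two feature vectors are (hypothetical helper)."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    return {
        "cosine": cosine,                              # 1.0 means same direction
        "max_abs_diff": float(np.max(np.abs(a - b))),  # worst single coordinate
        "rel_l2_diff": float(np.linalg.norm(a - b) / np.linalg.norm(b)),  # error relative to b
    }


# Usage (names taken from the snippets in this thread):
# print(compare_features(precomp_features[image_id], features))
```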

jmhessel commented 1 year ago

Hi! Posting some initial thoughts. @VegB and I can take a closer look at this soon.

I don't think there's a bug per se, but yes, there are differences between the default OpenAI extraction setting and what you get from the TPU extraction.

jmhessel commented 1 year ago

Here's a snippet of my code for the fp16 OpenAI results above. I ran it with device='cuda'.

```python
import clip
import numpy as np
import torch
import tqdm
from PIL import Image


class CLIPImageDataset(torch.utils.data.Dataset):
    """Wraps a list of image paths and applies CLIP's preprocessing."""

    def __init__(self, data, preprocess):
        self.data = data
        self.preprocess = preprocess

    def __getitem__(self, idx):
        c_data = self.data[idx]
        image = Image.open(c_data)
        # CLIP's transform: resize, center-crop, convert to RGB, to tensor, normalize
        image = self.preprocess(image)
        return {'image': image}

    def __len__(self):
        return len(self.data)


def extract_all_images(images, model, preprocess, device, batch_size=256, num_workers=8):
    # shuffle=False keeps the output rows in the same order as `images`
    data = torch.utils.data.DataLoader(
        CLIPImageDataset(images, preprocess),
        batch_size=batch_size, num_workers=num_workers, shuffle=False)
    all_image_features = []
    with torch.no_grad():
        for b in tqdm.tqdm(data):
            b = b['image'].to(device)
            if device == 'cuda':
                # clip.load keeps fp16 weights on GPU, so feed fp16 inputs
                b = b.to(torch.float16)
            all_image_features.append(model.encode_image(b).cpu().numpy())
    all_image_features = np.vstack(all_image_features)
    return all_image_features


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)
model.eval()
```

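The released features are looked up by image name (the shard-0 snippet later in the thread indexes them as `precomp_features[image_id]`), so it helps to key your own outputs the same way before comparing. A minimal sketch, assuming your local file names match mmc4's `image_name` values; `build_feature_dict` is a hypothetical helper.

```python
import os

import numpy as np


def build_feature_dict(image_paths, features):
    """Map each image's file name to its feature row (hypothetical helper).

    Relies on the DataLoader in extract_all_images using shuffle=False,
    so row i of `features` belongs to image_paths[i].
    """
    features = np.asarray(features)
    if len(image_paths) != features.shape[0]:
        raise ValueError("expected exactly one feature row per image path")
    return {os.path.basename(p): features[i] for i, p in enumerate(image_paths)}


# Usage, continuing the snippet above:
# import pickle
# paths = [...]  # your local image files
# feats = extract_all_images(paths, model, preprocess, device)
# with open("my_features.pkl", "wb") as f:
#     pickle.dump(build_feature_dict(paths, feats), f)
```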
josep-alana commented 1 year ago

Hi, thanks for the quick reply and the code snippet! To answer your initial question, here is a comparison of the precomputed features and the features I get for image_id = f5f8113da82f.jpg (https://i.kinja-img.com/gawker-media/image/upload/s--3J8F-YDp--/c_fit,f_auto,fl_progressive,q_80,w_320/836357908648846882.jpg) using both the Hugging Face model and the OpenAI model. I expected those two to differ from each other, since I believe the HF folks did their own training and the weights are likely different. But I still see a big difference between the precomputed features and the ones I get using the OpenAI code (I'm only showing the first few values):

I used a GeForce RTX 3080 GPU.

P.S.: Looking at your code snippet, I realised that I wasn't calling the model's eval() method. I also added a cast of the input image to fp16, but the result I now get with the OpenAI model is only slightly different from what I was getting before (I think this difference may be due to numerical precision):

```
[1.7432e-01, -2.8174e-01, -1.0083e-01,  3.8306e-01,
 -4.7412e-01, -1.5540e-01,  5.3741e-02,  6.6772e-02,
 -7.2998e-01, -4.6558e-01, -6.3281e-01, -1.1438e-01,
  1.1348e+00,  2.2621e-03,  3.0713e-01,  5.6152e-01, ...]
```

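For scale: storing a value in float16 perturbs it by at most about 0.05% (half a unit in the last of float16's 10 mantissa bits), which is consistent with the fp16 cast and eval() barely changing these numbers. A full fp16 forward pass accumulates more error than a single rounding, so treat this as a lower bound on precision noise rather than an estimate. `fp16_roundoff` is a hypothetical helper.

```python
import numpy as np


def fp16_roundoff(x):
    """Largest relative error from storing the float32 values x in float16."""
    x = np.asarray(x, dtype=np.float32)
    y = x.astype(np.float16).astype(np.float32)
    nz = x != 0  # skip exact zeros to avoid dividing by zero
    return float(np.max(np.abs(y[nz] - x[nz]) / np.abs(x[nz])))


# float16 machine epsilon is 2**-10, so round-to-nearest costs at most 2**-11
# relative error for normal (non-subnormal) values.
print(np.finfo(np.float16).eps)
print(fp16_roundoff(np.linspace(0.01, 2.0, 1000)))
```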
jmhessel commented 1 year ago

Quick question: which shard is that image from?

josep-alana commented 1 year ago

Oh, sorry, shard 0; this is the snippet of code that I used:

```python
import clip
import json
import torch
import pickle
import requests
from PIL import Image

mmc4_data = '/storage/datasets/mmc4/docs_no_face_shard_0_v2.jsonl'
with open(mmc4_data, 'r') as f:
    line = f.readline()  # first document in the shard
    data = json.loads(line)

with open('/storage/datasets/mmc4/clip_vitl14_shard_0_features.pkl', 'rb') as f:
    precomp_features = pickle.load(f)  # indexed by image_name

# First image of that document, fetched fresh from its original URL
image_id = data['image_info'][0]['image_name']
image_url = data['image_info'][0]['raw_url']
image = Image.open(requests.get(image_url, stream=True).raw)

model, processor = clip.load('ViT-L/14', device='cuda')
model.eval()

with torch.no_grad():
    inputs = processor(image).unsqueeze(0).to('cuda').to(torch.float16)
    features = model.encode_image(inputs).cpu().numpy()

print(precomp_features[image_id])
print(features)
```

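The snippet above checks only the first image of the first document. To spot-check more images, one could walk the whole shard. A minimal sketch, assuming each line of the .jsonl is one JSON document with an `image_info` list using the same field names as the snippet; `iter_image_records` is a hypothetical helper. Keep in mind that an image re-downloaded from `raw_url` may not be byte-identical to the copy the released features were computed from.

```python
import json


def iter_image_records(lines):
    """Yield (image_name, raw_url) for every image in an mmc4 shard (hypothetical helper).

    `lines` is any iterable of JSON strings, e.g. an open .jsonl file.
    """
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines
        doc = json.loads(line)
        for info in doc.get("image_info", []):
            yield info["image_name"], info["raw_url"]


# Usage:
# with open(mmc4_data) as f:
#     for image_id, image_url in iter_image_records(f):
#         ...
```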
jmhessel commented 1 year ago

Thanks! When I recompute, I get:

```
# fp16 openai
[ 0.5537  -0.268    0.07825  0.593   -0.7036 ]
# our released features
[ 0.55724317 -0.25615692  0.09603771  0.58884245 -0.6886256 ]
```

But this is based on a copy of the image I downloaded earlier, and it's possible the image has changed since then. I redownloaded the jpg and ran again:

```
[0.1747, -0.2837, -0.1021, 0.3813, -0.4746]
```

This matches your result. So the difference appears to come from the image files themselves. Here they are:

This is the version from which the mmc4 image features were computed:

[image: from_local_copy.jpg]

This is the version from the web:

[image: from_web.jpg]

They look identical to me, but their md5 hashes differ, so they are indeed different files:

```
137abb1e7c7181783f9b5a55c76e0a90  from_local_copy.jpg
3b382f143807db5c54573b9005440e75  from_web.jpg
```
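A differing md5 only says the bytes differ: two files can differ in metadata or encoder settings and still decode to the same pixels. Comparing the decoded pixels as well tells you which case you are in. A minimal sketch using `hashlib` and Pillow; `md5_hex` and `max_pixel_diff` are hypothetical helpers.

```python
import hashlib
import io

import numpy as np
from PIL import Image


def md5_hex(data: bytes) -> str:
    """MD5 of raw file bytes, the same digest md5sum prints."""
    return hashlib.md5(data).hexdigest()


def max_pixel_diff(data_a: bytes, data_b: bytes):
    """Largest per-channel difference after decoding both files to RGB.

    Returns None when the decoded sizes differ (e.g. one copy was resized).
    """
    a = np.asarray(Image.open(io.BytesIO(data_a)).convert("RGB"), dtype=np.int16)
    b = np.asarray(Image.open(io.BytesIO(data_b)).convert("RGB"), dtype=np.int16)
    if a.shape != b.shape:
        return None
    return int(np.abs(a - b).max())


# Usage with the two files from this thread:
# with open("from_local_copy.jpg", "rb") as fa, open("from_web.jpg", "rb") as fb:
#     a, b = fa.read(), fb.read()
# print(md5_hex(a), md5_hex(b), max_pixel_diff(a, b))
```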

Let me think about this...

jmhessel commented 1 year ago

OK, I solved the mystery. We used this library for image resizing:

https://github.com/ImageMagick/ImageMagick

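Specifically, we ran this command, which shrinks any image whose width or height exceeds 800 pixels so that it fits within 800x800 (the aspect ratio is preserved, and smaller images are not enlarged):

```
mogrify -resize "800x800>" *
```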
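In ImageMagick's geometry syntax, the trailing `>` means "only shrink images larger than the box", and the quotes keep the shell from treating `>` as a redirect. Note that `mogrify` overwrites each file in place. A minimal sketch of the resulting size; the rounding is my assumption and may be off by a pixel from ImageMagick, and this does not reproduce ImageMagick's resampling filter or JPEG re-encoding, so it cannot stand in for `mogrify` when matching the released features. `shrink_to_fit` is a hypothetical helper.

```python
def shrink_to_fit(width: int, height: int, max_edge: int = 800):
    """Approximate output size of ImageMagick's "800x800>" geometry (hypothetical helper)."""
    if width <= max_edge and height <= max_edge:
        return width, height  # ">" leaves images that already fit untouched
    scale = max_edge / max(width, height)  # longer edge becomes max_edge
    return max(1, round(width * scale)), max(1, round(height * scale))


print(shrink_to_fit(1600, 1200))
print(shrink_to_fit(320, 240))
```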

If I run that command on the image from the web, and then run CLIP again, I can reproduce the extracted feature:

```
# raw from web
[ 0.175   -0.2808  -0.09985  0.3816  -0.4724 ]
# after running mogrify -resize "800x800>"
[ 0.5537  -0.268    0.07825  0.593   -0.7036 ]
# released feature
[ 0.5537  -0.268    0.07825  0.593   -0.7036 ]
```
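To apply the same step to other downloaded images without clobbering the originals, one could run `mogrify` on a copy before feeding the image to CLIP. A minimal sketch, assuming ImageMagick's `mogrify` is on `PATH` (the thread names 6.9.10-23 as the version that reproduces the released features; check with `mogrify --version`). `mogrify_cmd` and `resized_copy` are hypothetical helpers.

```python
import shutil
import subprocess
from pathlib import Path


def mogrify_cmd(path, geometry="800x800>"):
    # As an argument list (no shell), the ">" needs no quoting.
    return ["mogrify", "-resize", geometry, str(path)]


def resized_copy(src, workdir):
    """Copy src into workdir and shrink the copy with mogrify (hypothetical helper).

    mogrify overwrites its input, so it is only ever run on the copy.
    """
    dst = Path(workdir) / Path(src).name
    shutil.copy2(src, dst)
    subprocess.run(mogrify_cmd(dst), check=True)
    return dst


# Usage:
# import tempfile
# with tempfile.TemporaryDirectory() as d:
#     image = Image.open(resized_copy("from_web.jpg", d))
#     # preprocess and encode_image here, before the directory is deleted
```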

The exact version of ImageMagick that reproduces this is:

```
$ mogrify --version
Version: ImageMagick 6.9.10-23 Q16 x86_64 20190101 https://imagemagick.org
```

It's interesting that this operation changes the CLIP feature for this image as much as it does; I wouldn't have guessed that.
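Part of why a visually identical file can shift the features: decoding and re-encoding a JPEG is lossy even when the resolution does not change. A minimal sketch of that effect using Pillow's encoder on synthetic noise; it illustrates the mechanism only and does not reproduce ImageMagick's encoder, and `quality=92` is an arbitrary choice. `jpeg_roundtrip` is a hypothetical helper.

```python
import io

import numpy as np
from PIL import Image


def jpeg_roundtrip(img, quality=92):
    """Re-encode a PIL image as JPEG in memory and decode it again."""
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


rng = np.random.default_rng(0)
# High-frequency noise stands in for fine image detail
original = Image.fromarray(rng.integers(0, 256, (64, 64, 3), dtype=np.uint8))
recoded = jpeg_roundtrip(original)
diff = np.abs(np.asarray(original, np.int16) - np.asarray(recoded, np.int16))
print(diff.mean(), diff.max())
```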

josep-alana commented 1 year ago

Wow, that's crazy! Although I'm not that surprised, since I've run into issues of a similar nature before :smile: I'm curious: why do you resize the images beforehand? Doesn't CLIP's preprocessing step resize images to 224x224 anyway?

At any rate, I'm glad you solved the mystery! :tada: Thanks!

jmhessel commented 1 year ago

Awesome! Thanks for your help on this. I'll close this issue, since the mystery seems solved and no updates are required. CLIP does do its own resizing (I think ViT-L may use an input larger than 224). But an extra round of JPEG compression before that resizing could still cause this.
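For reference, OpenAI's `clip.load` builds its preprocessing from `model.visual.input_resolution`, which is 224 for "ViT-L/14" (the separate "ViT-L/14@336px" checkpoint uses 336). The transform resizes the shorter side to that size with bicubic interpolation and then center-crops a square, so CLIP only ever sees a 224x224 crop; what the earlier resize and re-encode change is the pixel content going into that downsampling. A minimal sketch of the geometry; the rounding mirrors torchvision's `Resize`/`CenterCrop` as I understand them, so treat it as approximate. `clip_resize_and_crop` is a hypothetical helper.

```python
def clip_resize_and_crop(width: int, height: int, n_px: int = 224):
    """Geometry of Resize(n_px) followed by CenterCrop(n_px) (hypothetical helper).

    Returns (resized_size, crop_box) with crop_box = (left, top, right, bottom).
    """
    short, long = min(width, height), max(width, height)
    new_long = int(n_px * long / short)  # shorter side becomes n_px, aspect ratio kept
    new_w, new_h = (n_px, new_long) if width <= height else (new_long, n_px)
    left = int(round((new_w - n_px) / 2.0))
    top = int(round((new_h - n_px) / 2.0))
    return (new_w, new_h), (left, top, left + n_px, top + n_px)


print(clip_resize_and_crop(800, 600))
```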

We resize beforehand because you'd be surprised what you find in 1.4B web images: enormous gigabyte-sized images, JPEG decompression bombs, etc.
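A guard of the kind described here can check an image's dimensions before decoding any pixels, because Pillow's `Image.open` only reads the header. A minimal Pillow-based sketch, not the mmc4 pipeline (which used `mogrify`, so this will not reproduce the released features); `check_and_shrink` is a hypothetical helper and the 50-megapixel cap is an arbitrary assumption. Pillow also has its own `Image.MAX_IMAGE_PIXELS` limit, warning above it and raising `DecompressionBombError` above twice that value.

```python
from PIL import Image


def check_and_shrink(img, max_pixels=50_000_000, max_edge=800):
    """Reject oversized images before decoding, then shrink the rest (hypothetical helper).

    Pass an image from Image.open(): it reads only the header, so img.size is
    known before any pixel data is decoded.
    """
    w, h = img.size
    if w * h > max_pixels:
        raise ValueError(f"refusing to decode a {w}x{h} image")
    img = img.convert("RGB")  # pixels are decoded here
    img.thumbnail((max_edge, max_edge))  # in place; only shrinks, keeps aspect ratio
    return img
```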