mlfoundations / open_clip

An open source implementation of CLIP.

Enable passing `output_hidden_states` #731

Open thepowerfuldeez opened 11 months ago

thepowerfuldeez commented 11 months ago

Related to #657. Inspired by the PR above, I made this PR without breaking backward compatibility. In addition, I added support for passing output_hidden_states as an attribute on the VisionTransformer and TextTransformer classes.

Example:

import torch
import open_clip
from PIL import Image

model_type = 'hf-hub:UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B'

model, _, preprocess = open_clip.create_model_and_transforms(model_type)
tokenizer = open_clip.get_tokenizer(model_type)

# PATH_TO_IMAGE is a placeholder for a local image path
image = preprocess(Image.open(PATH_TO_IMAGE)).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    # with output_hidden_states=True, encode_image/encode_text return the
    # pooled features together with the per-block hidden states
    image_result = model.encode_image(image, output_hidden_states=True)
    text_result = model.encode_text(text, output_hidden_states=True)

    image_features, image_hidden_states = image_result
    text_features, text_hidden_states = text_result
rwightman commented 11 months ago

@thepowerfuldeez thanks for the PR. I will say that we need to do this one carefully as it impacts the output interface. I recognize that people want this, but it's been slow to be added because it's a bit of a mess once you consider all the details.

First things first, I feel we should only allow this if dictionary output is enabled; having too many tuple variations as possible outputs is asking for trouble.

Next, the internal typing has gotchas with torchscript when you alternate between Tuple and tensor outputs. I'm not quite sure what combination of typing would be needed to make that pass.
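
Roughly, the direction I'd expect to work (a sketch only, not the actual open_clip internals, and I haven't threaded it through all the forward paths) is to keep the return type fixed and make the hidden states an Optional, rather than switching between a bare tensor and a tuple:

from typing import List, Optional, Tuple

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    # illustrative only: a fixed (Tensor, Optional[List[Tensor]]) return type
    # instead of sometimes-tensor / sometimes-tuple outputs
    def __init__(self, width: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])

    def forward(
        self,
        x: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        hidden_states: List[torch.Tensor] = []
        for block in self.blocks:
            x = block(x)
            if output_hidden_states:
                hidden_states.append(x)
        if output_hidden_states:
            return x, hidden_states
        return x, None

scripted = torch.jit.script(DummyEncoder())  # sanity check that this combination scripts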

thepowerfuldeez commented 11 months ago

Hi @rwightman! I'm on the same page with you, agreed that it should be supported as a dict output. I couldn't decide how best to do it, since the dict output appears in only one place, and I needed the VisionTransformer itself to output hidden states (I am not using the CLIP class, hence I implemented the logic by setting an attribute on the transformer classes). What would be a better way to move this logic into the dict output here?
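
For reference, roughly what I had in mind for the dict route (just a sketch; the hidden-state keys are illustrative, only the image_features / text_features / logit_scale keys exist today):

from typing import List, Optional

import torch

def build_output(
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    image_hidden_states: Optional[List[torch.Tensor]] = None,
    text_hidden_states: Optional[List[torch.Tensor]] = None,
):
    # hypothetical helper, not current open_clip code: only the dict path
    # grows new keys, so callers never see new tuple variations
    out = {
        "image_features": image_features,
        "text_features": text_features,
        "logit_scale": logit_scale,
    }
    if image_hidden_states is not None:
        out["image_hidden_states"] = image_hidden_states
    if text_hidden_states is not None:
        out["text_hidden_states"] = text_hidden_states
    return out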

alvaro-stylesage commented 6 months ago

Hi @thepowerfuldeez, thanks a lot for this PR, it has been really useful. However, I have some doubts when using the image_hidden_states as an embedding for downstream classification tasks. I am doing:

last_hidden_state = image_hidden_states[-1].numpy()
cls_embedding = last_hidden_state[:, 0, :]

But then, using that (1024,)-sized CLS embedding does not give good classification metrics for my task (~50% accuracy), while with the (768,)-sized image_features I get ~85% accuracy. Can you think of an explanation for this? Is the last hidden state taken before the LayerNorm, and could that be affecting the results?
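
My guess (which I have not verified) is that image_features is the pooled token after the visual tower's final LayerNorm and output projection, so the raw last hidden state lives in a different space. Something like the following would be needed to make them comparable, assuming the tower exposes ln_post and proj and pools the way I expect:

# rough, unverified sketch: map the last hidden state into the same space as image_features
with torch.no_grad():
    last_hidden_state = image_hidden_states[-1]    # (1, seq_len, 1024)
    pooled = last_hidden_state[:, 0, :]            # CLS pooling; some configs mean-pool instead (last_hidden_state.mean(dim=1))
    pooled = model.visual.ln_post(pooled)          # final LayerNorm
    projected = pooled @ model.visual.proj         # project into the (1, 768) shared embedding space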

Thanks!