Closed mr-sarthakgupta closed 1 week ago
Hi there 👋 this is possible, but you need to export the model (to ONNX) with those parameters defined. Do you know which model you want to run?
I'd like to run these models:
Hi @xenova please let me know how I could begin. Also, would it require changes for each model separately or could be done with a single change?
This resource should be able to help out: https://huggingface.co/docs/optimum/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models. Let me know if you have any questions! Since you're only focused on CLIP, it would just require a single change.
Here's the original CLIP ONNX config which you can use as inspiration: https://github.com/huggingface/optimum/blob/db51410ae5ef4cbde7518cf01a997239dffbde1d/optimum/exporters/onnx/model_configs.py#L889-L908 (all you probably need to do is add a value for the hidden layers).
Hi, thanks for the resources, but I can't seem to understand how the code is functioning. On the surface it seems we're only passing dicts with dimension sizes. Please correct me if I'm wrong and let me know where I could read up and better understand what's happening.
Also: Would changing this require changing the way we load the model in transformersjs? Or we'd still be able to use something like CLIPModel.from_pretrained?
On the surface it seems we're only passing dicts with dimension sizes.
That's pretty much it - yes! torch.onnx.export does the tracing for us and generates the graph.
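To make the "dicts" point a bit more concrete: the keys are axis indices and the values are symbolic names for the axes that are allowed to vary, rather than fixed sizes. As far as I understand, any axis that is not listed keeps the size it had during export. Below is a minimal pure-Python sketch of the kind of mapping an `outputs` property builds. The helper name `build_output_axes` and the layer count of 12 are assumptions for illustration only and are not part of optimum.

```python
from typing import Dict


def build_output_axes(embeds_name: str, num_layers: int) -> Dict[str, Dict[int, str]]:
    """Map each output name to {axis index: symbolic axis name}.

    Hypothetical helper for illustration only; it is not an optimum API.
    """
    # The projected embedding only varies along the batch axis (axis 0).
    axes = {embeds_name: {0: "batch_size"}}
    # One entry per hidden state: the embedding output plus one per layer.
    for i in range(num_layers + 1):
        axes[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
    return axes


# num_layers=12 is an assumed value for this example.
axes = build_output_axes("text_embeds", num_layers=12)
print(len(axes))
print(axes["hidden_states.0"])
```

Each extra `hidden_states.{i}` entry simply declares one more named output of the exported graph, which is why adding hidden states is mostly a matter of extending this dict.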
Also: Would changing this require changing the way we load the model in transformersjs? Or we'd still be able to use something like CLIPModel.from_pretrained?
Nope, no change to usage! :)
If you have example python code for what you'd like to achieve, I can help create the equivalent JS code. Do you have a link to a tutorial/blog or something?
Something like this should work:
pip install --upgrade 'onnx==1.13.1' 'onnxruntime<1.16.0' 'optimum==1.20.0'
from optimum.exporters.onnx.model_configs import CLIPTextOnnxConfig, ViTOnnxConfig
from typing import Dict


class CLIPTextModelWithProjectionOnnxConfig(CLIPTextOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        # Projected text embedding plus one hidden state per layer (and the embedding layer)
        outputs = {"text_embeds": {0: "batch_size"}}
        for i in range(self._normalized_config.num_layers + 1):
            outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
        return outputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        if framework == "pt":
            import torch
            # Export with int64 input ids
            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int64)
        return dummy_inputs


class CLIPVisionModelWithProjectionOnnxConfig(ViTOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        # Projected image embedding plus one hidden state per layer (and the embedding layer)
        outputs = {"image_embeds": {0: "batch_size"}}
        for i in range(self._normalized_config.num_hidden_layers + 1):
            outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
        return outputs
from optimum.exporters.onnx import export_models
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection

model_id = "openai/clip-vit-base-patch32"
text_model = CLIPTextModelWithProjection.from_pretrained(model_id)
vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id)

output_dir = 'custom'
export_models(
    models_and_onnx_configs={
        "text_model": (text_model, CLIPTextModelWithProjectionOnnxConfig(text_model.config)),
        "vision_model": (vision_model, CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)),
    },
    output_dir=output_dir,
    model_kwargs={
        'output_hidden_states': True
    }
)
# Move to onnx subfolder
import os
onnx_path = os.path.join(output_dir, 'onnx')
os.makedirs(onnx_path, exist_ok=True)
os.rename(os.path.join(output_dir, 'text_model.onnx'), os.path.join(onnx_path, 'text_model.onnx'))
os.rename(os.path.join(output_dir, 'vision_model.onnx'), os.path.join(onnx_path, 'vision_model.onnx'))
# Also save other files
from transformers import AutoConfig, AutoTokenizer, AutoProcessor
AutoConfig.from_pretrained(model_id).save_pretrained(output_dir)
AutoTokenizer.from_pretrained(model_id).save_pretrained(output_dir)
AutoProcessor.from_pretrained(model_id).save_pretrained(output_dir)
from huggingface_hub import login
login()
followed by
from huggingface_hub import HfApi, create_repo
api = HfApi()
repo_id="YOUR_REPO_ID"
create_repo(repo_id, exist_ok=True)
api.upload_folder(folder_path=output_dir, repo_id=repo_id)
I've uploaded this demo to https://huggingface.co/onnx-community/clip-vit-base-patch32_hidden-states
NOTE: Since we haven't made any quantized versions, you need to specify { quantized: false } as the second parameter in .from_pretrained.
Example: Compute text embeddings and hidden states with CLIPTextModelWithProjection.
import { AutoTokenizer, CLIPTextModelWithProjection } from '@xenova/transformers';
// Load tokenizer and text model
const tokenizer = await AutoTokenizer.from_pretrained('onnx-community/clip-vit-base-patch32_hidden-states');
const text_model = await CLIPTextModelWithProjection.from_pretrained('onnx-community/clip-vit-base-patch32_hidden-states', { quantized: false });
// Run tokenization
const texts = ['a photo of a car', 'a photo of a football match'];
const text_inputs = tokenizer(texts, { padding: true, truncation: true });
// Compute embeddings and hidden states
const { text_embeds, ...text_hidden_states } = await text_model(text_inputs);
Example: Compute vision embeddings and hidden states with CLIPVisionModelWithProjection.
import { AutoProcessor, CLIPVisionModelWithProjection, RawImage } from '@xenova/transformers';
// Load processor and vision model
const processor = await AutoProcessor.from_pretrained('onnx-community/clip-vit-base-patch32_hidden-states');
const vision_model = await CLIPVisionModelWithProjection.from_pretrained('onnx-community/clip-vit-base-patch32_hidden-states', { quantized: false });
// Read image and run processor
const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
const image_inputs = await processor(image);
// Compute embeddings and hidden states
const { image_embeds, ...vision_hidden_states } = await vision_model(image_inputs);
text_hidden_states and vision_hidden_states are now objects containing the hidden states per layer.
Hope that helps!
That's really really helpful!! Thanks @xenova, you're the best!
Hi @xenova, thanks for all the help. Lately I've noticed that the hidden states I obtain from transformersjs are different from the ones obtained from Python. I've been trying to get the computation graph to find out if there's any difference, but to no avail.
This is typically due to minor pre-processing differences in images (sharp in js and pillow in py). How large is the difference?
The L2 distance between the result obtained from Python and the one from transformersjs is ~350. Given that the values are mostly in the range [-1.2, 1.2] with vector shape [257, 1024], this doesn't seem all that significant, but I'm concerned that it could lead to degraded results.
I also checked the input, it has an L2 distance of ~78 for size [3, 224, 224] with most values in [-1, 1]
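One way to put these numbers in perspective is to convert each L2 distance into a per-element root-mean-square (RMS) difference, by dividing by the square root of the number of elements. The sketch below uses only the approximate figures quoted above; the helper name `rms_from_l2` is an assumption for illustration.

```python
import math


def rms_from_l2(l2_distance: float, shape: tuple) -> float:
    """Per-element RMS difference implied by an L2 distance over a tensor of `shape`."""
    num_elements = math.prod(shape)
    return l2_distance / math.sqrt(num_elements)


# Approximate figures quoted in the thread.
hidden_rms = rms_from_l2(350.0, (257, 1024))
input_rms = rms_from_l2(78.0, (3, 224, 224))

print(f"hidden states: {hidden_rms:.3f}")
print(f"inputs:        {input_rms:.3f}")
```

With these figures the hidden states differ by roughly 0.68 per element and the inputs by roughly 0.20. Neither is small next to values that mostly lie in about [-1, 1], so it may be worth ruling out pre-processing first, for example by feeding identical pixel values to both runtimes and comparing again.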
Feature request
In the transformers library, we could pass the argument output_hidden_states = True to receive the activations for all the hidden layers of the model.
Motivation
This is a useful feature for tasks which might use pretrained models and use their hidden activations to read features and do certain tasks.
Your contribution
I don't have a lot of depth in this repository at the moment but would love to learn more and contribute this feature. Any resources/suggestions where I should begin are really appreciated!