xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js
Apache License 2.0

Add class for CLIPVisionModel #816

Open: mr-sarthakgupta opened this issue 1 week ago

mr-sarthakgupta commented 1 week ago

Model description

A Transformers.js equivalent of the Python CLIPVisionModel: https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/clip#transformers.CLIPVisionModel

Prerequisites

Additional information

We could use Optimum to export the entire CLIP model to ONNX and then use Transformers.js to load the CLIPVisionModel part of the exported model. Since CLIPVisionModelWithProjection is already supported, I believe we could reuse its class to obtain CLIPVisionModel as well.
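
For reference, and as I understand the Python CLIP implementation in transformers, the two classes share the same vision tower and differ only in the last step. CLIPVisionModel returns pooler_output ($\mathbf{p}$ below), the class-token state of the last encoder layer after the post-LayerNorm, while CLIPVisionModelWithProjection additionally applies a bias-free linear layer (visual_projection) to get image_embeds ($\mathbf{e}$). The notation is mine, not the library's:

```latex
\mathbf{p} = \operatorname{LayerNorm}_{\text{post}}\!\left(\mathbf{h}^{(L)}_{\text{cls}}\right) \in \mathbb{R}^{d_{\text{hidden}}},
\qquad
\mathbf{e} = W_{\text{proj}}\,\mathbf{p} \in \mathbb{R}^{d_{\text{proj}}},
\qquad
W_{\text{proj}} \in \mathbb{R}^{d_{\text{proj}} \times d_{\text{hidden}}}
```

For openai/clip-vit-base-patch16 this would mean $d_{\text{hidden}} = 768$ and $d_{\text{proj}} = 512$, so an export that exposes pooler_output should be enough for a CLIPVisionModel equivalent.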

Your contribution

Given access to relevant resources, I'd be more than happy to contribute to this repo by creating a class for this model type!

mr-sarthakgupta commented 1 week ago

I suppose we have to make a class very similar to this one: https://github.com/xenova/transformers.js/blob/30720089c6215f71b519c17369404cdfd14c32e0/src/models.js#L3164. What other changes do we need to make in order to load the equivalent of CLIPVisionModel instead of CLIPVisionModelWithProjection?

@xenova, please help!

xenova commented 1 week ago

Hey there! 👋 That's right, you need to do two things:

  1. Modify/create a custom ONNX export like this
  2. Define a new class called CLIPVisionModel like the code you linked to.

If you have example Python code for running the models, feel free to post it here and I can show you the equivalent JS code 👍

mr-sarthakgupta commented 1 week ago

Thanks for the help, @xenova! I think the ONNX export config class for this model would look like this:

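from typing import Dict

from optimum.exporters.onnx.model_configs import ViTOnnxConfig

class CLIPVisionOnnxConfig(ViTOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        # pooler_output: post-LayerNorm class-token embedding, one per image
        outputs = {"pooler_output": {0: "batch_size"}}
        # hidden_states: embedding output plus one entry per encoder layer
        for i in range(self._normalized_config.num_hidden_layers + 1):
            outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
        return outputs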
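
To make the `num_hidden_layers + 1` count concrete, here is a minimal sketch that rebuilds the same axes dictionary outside of Optimum for a given layer count. The helper name `clip_vision_output_axes` is hypothetical, and the layer count of 12 assumes the ViT-B/16 vision tower of openai/clip-vit-base-patch16.

```python
from typing import Dict


def clip_vision_output_axes(num_hidden_layers: int) -> Dict[str, Dict[int, str]]:
    """Hypothetical helper mirroring the `outputs` property above."""
    outputs: Dict[str, Dict[int, str]] = {"pooler_output": {0: "batch_size"}}
    # hidden_states has num_hidden_layers + 1 entries: the embedding output
    # followed by the output of each encoder layer.
    for i in range(num_hidden_layers + 1):
        outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
    return outputs


# Assumption: the ViT-B/16 vision tower has 12 encoder layers.
axes = clip_vision_output_axes(12)
print(len(axes))  # 14: pooler_output plus hidden_states.0 ... hidden_states.12
print(list(axes)[:3])
```

Only dimension 0 of pooler_output is dynamic because its second dimension is the fixed hidden size, whereas each hidden state also carries a sequence axis.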

And I suppose the CLIPVisionModel class would be identical to the CLIPVisionModelWithProjection class except for the name.

And here's the Python code for running the model:


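import torch
from transformers import CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")

# Random input with the expected shape: (batch, channels, height, width)
with torch.no_grad():
    outputs = model(pixel_values=torch.randn(1, 3, 224, 224))

print(outputs.pooler_output.shape)  # torch.Size([1, 768])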
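
As a sanity check on what that script should print, the output shapes follow directly from the vision config. This is a sketch assuming openai/clip-vit-base-patch16 uses image_size=224, patch_size=16, and hidden_size=768; the helper `expected_clip_vision_shapes` is hypothetical.

```python
from typing import Tuple


def expected_clip_vision_shapes(
    batch: int, image_size: int, patch_size: int, hidden_size: int
) -> Tuple[Tuple[int, int, int], Tuple[int, int]]:
    """Return the expected (last_hidden_state, pooler_output) shapes."""
    patches_per_side = image_size // patch_size
    # One token per patch, plus one for the class token.
    seq_len = patches_per_side * patches_per_side + 1
    return (batch, seq_len, hidden_size), (batch, hidden_size)


last, pooled = expected_clip_vision_shapes(1, 224, 16, 768)
print(last)    # (1, 197, 768)
print(pooled)  # (1, 768)
```

The extra token is the class token; its final state, after the post-LayerNorm, is what pooler_output holds.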

mr-sarthakgupta commented 5 days ago

Hi @xenova, please let me know if the following looks good; I can open a PR if the code is alright:

Code for transformersjs/scripts/extra/clip.py:

class CLIPVisionModelOnnxConfig(CLIPVisionOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        outputs = {"pooler_output": {0: "batch_size"},}
        for i in range(self._normalized_config.num_hidden_layers + 1):
            outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
        return outputs

For transformersjs/src/models.js:

export class CLIPVisionModel extends CLIPPreTrainedModel {
    /** @type {PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Update default model file name if not provided
        options.model_file_name ??= 'vision_model';
        return super.from_pretrained(pretrained_model_name_or_path, options);
    }
}

mr-sarthakgupta commented 2 days ago

Hi @xenova, thanks a lot for building such a fantastic repository; I truly appreciate all the hard work you've put in! I understand that it must be difficult to keep up with such an active repository. Could you let me know whether I'm good to go with the code above?

I know it's not ideal to ask for a quick response after everything you've already built, but I need this feature for a project whose deadline is just around the corner, so it would be very kind of you to help me get there quickly. Thanks again.

xenova commented 2 days ago

Hi again 👋 Sure, that looks like it will work. Have you tried exporting a model with that config? The usage will be similar to https://github.com/xenova/transformers.js/issues/799#issuecomment-2171410970.

To use it from JavaScript, you can simply import CLIPPreTrainedModel and extend the class yourself in your code. That said, I might as well add the default configs to the library, since they are already defined in transformers (Python).

mr-sarthakgupta commented 1 day ago

Hi again, thanks for your response! I was able to run the model by extending the class. However, I'm observing a large deviation in the output values when I run the model in JavaScript. After exporting the model to ONNX, I tested its outputs against the PyTorch model: the fp32 and fp16 models have an average deviation of ~10^-5, which is very good, while the quantized model has a very large deviation, which is expected. However, when I run the model in JavaScript, it gives the same deviated output even when I explicitly pass quantized: false. Here is my code:


import { AutoProcessor, RawImage, CLIPPreTrainedModel, PreTrainedModel } from '@xenova/transformers';
import fs from 'fs';

export class CLIPVisionModel extends CLIPPreTrainedModel {
    /** @type {PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Update default model file name if not provided
        options.model_file_name ??= 'vision_model';
        return super.from_pretrained(pretrained_model_name_or_path, options);
    }
}

const processor = await AutoProcessor.from_pretrained('mrsarthakgupta/onnx_test_repo_3');
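const vision_model = await CLIPVisionModel.from_pretrained('mrsarthakgupta/onnx_test_repo_3', {
    // `device` and `dtype` are Transformers.js v3 options;
    // `quantized: false` is the v2 way to request the fp32 model file
    device: 'webgpu',
    dtype: 'fp32',
    quantized: false,
});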

// Read image and run processor
const image = await RawImage.read('F-VF8LyaEAAF0s9.jpg');
const time_start = performance.now();
const image_inputs = await processor(image);

const outs = await vision_model(image_inputs);
console.log(outs);

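// A Float32Array stringifies as {"0": ..., "1": ...}; convert it to a plain array first
const jsonString = JSON.stringify(Array.from(outs.pooler_output.data), null, 2);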

fs.writeFile('pooleroutjs.json', jsonString, (err) => {
    if (err) {
        console.error('Error writing file:', err);
    } else {
        console.log('File has been written successfully');
    }
});

Edit: I've verified that the processed inputs to the Python and JS models are almost identical, and I can't explain this difference at all. In fact, I'm using the same preprocessor_config.json file to preprocess the input for the Transformers.js model, the ONNX Runtime session, and the Python transformers model. While the outputs of Python transformers and ONNX Runtime agree, the outputs from Transformers.js don't.
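
One way to put numbers on "agree" and "don't agree" for both the preprocessed pixel_values and pooler_output is a small pure-Python comparison. This is a sketch; the helper names `load_vector` and `compare` are hypothetical, and the toy vectors stand in for real dumps. It also accepts the object form that JSON.stringify produces for a Float32Array, ordering entries by numeric key.

```python
import json
import math
from typing import Dict, List, Sequence, Union


def load_vector(obj: Union[list, Dict[str, float]]) -> List[float]:
    """Turn a dumped output vector into a flat list of floats.

    JSON.stringify on a Float32Array yields an object keyed by index
    ({"0": ..., "1": ...}), so dict input is ordered by numeric key.
    """
    if isinstance(obj, dict):
        return [float(obj[k]) for k in sorted(obj, key=int)]
    return [float(x) for x in obj]


def compare(a: Sequence[float], b: Sequence[float]) -> Dict[str, float]:
    """Mean/max absolute difference and cosine similarity of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return {
        "mean_abs_diff": sum(diffs) / len(diffs),
        "max_abs_diff": max(diffs),
        "cosine": dot / (norm_a * norm_b),
    }


# Toy data standing in for the real dumps (e.g. json.load of pooleroutjs.json).
py_vec = [0.5, -1.0, 2.0, 0.25]
js_vec = load_vector(json.loads('{"0": 0.5, "1": -1.0, "2": 2.00001, "3": 0.25}'))
stats = compare(py_vec, js_vec)
print(stats)
```

On the Python side, the matching dump could come from something like `outputs.pooler_output.detach().flatten().tolist()` written with `json.dump`, so both files hold the same flat layout as `pooler_output.data` in JS.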

I tried it with transformersjs#3 and with 2.17.1, specifying the fp32 model file (vision_model.onnx), quantized: false, and device: 'cpu', and it still consistently gives the same result. I'm using Node to run this file.