huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add optimized `PixtralImageProcessorFast` #34836

Open mgoin opened 1 day ago

mgoin commented 1 day ago

What does this PR do?

This PR implements a fast image processor for Pixtral, following up on issue https://github.com/huggingface/transformers/issues/33810.

The key acceleration comes from replacing Pillow/NumPy arrays and functions (resize, rescale, normalize) with torch tensors and torchvision v2 functions. The fast processor also supports torch.compile and passing device="cuda" at inference time to run preprocessing on the GPU. One limitation is that only return_tensors="pt" is supported.

Usage

import torch
from transformers import AutoImageProcessor

slow_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=False)
fast_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)
compiled_processor = torch.compile(fast_processor, mode="reduce-overhead")
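
For example, a minimal inference sketch (the image file here is hypothetical; per this PR, only return_tensors="pt" is supported, and device="cuda" runs preprocessing on the GPU):

from PIL import Image

image = Image.open("example.jpg")  # hypothetical local image
inputs = fast_processor(images=image, return_tensors="pt", device="cuda")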

From simple benchmarking with a single image of size [3, 876, 1300], I see a 6x to 10x speedup. All times below are in milliseconds.

Processor                  Mean     Median   Std Dev   Min      Max
Slow (PIL Image)           23.680   23.098   2.240     21.824   36.064
Fast (PIL Image)            3.759    3.762   0.133      3.556    4.223
Compiled (PIL Image)        4.632    4.794   1.086      3.488   11.707
Slow (Torch Image)         22.331   21.878   1.821     21.316   36.603
Fast (Torch Image)          2.242    2.209   0.164      2.182    3.803
Compiled (Torch Image)      2.125    2.117   0.073      2.062    2.594
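
For reference, a minimal sketch of the kind of timing loop behind these numbers (a hedged reconstruction, since the exact benchmark script isn't part of this PR; it assumes a random torch image of the same shape):

import statistics
import time

import torch
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)
image = torch.randint(0, 256, (3, 876, 1300), dtype=torch.uint8)  # same shape as the benchmarked image

times = []
for _ in range(100):  # iteration count is illustrative
    start = time.perf_counter()
    processor(images=image, return_tensors="pt")
    times.append((time.perf_counter() - start) * 1000)  # per-call latency in milliseconds

print(f"Mean: {statistics.mean(times):.3f}")
print(f"Median: {statistics.median(times):.3f}")
print(f"Std Dev: {statistics.stdev(times):.3f}")
print(f"Min: {min(times):.3f}")
print(f"Max: {max(times):.3f}")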

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

qubvel commented 1 day ago

Hi @mgoin! Sounds great! Thanks for working on this 🤗

cc @yonigozlan maybe if you have bandwidth

mgoin commented 1 day ago

Thanks for the review and context @yonigozlan! I will look into it later today. Yes, you are correct about using it within a Processor; however, I have tested that this works within vLLM simply by adding use_fast=True to our AutoProcessor.from_pretrained() call here. There is no need to manually specify the Processor class.
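
For illustration, a sketch of that call (assuming the same Pixtral checkpoint; the exact vLLM call site is the one linked above):

from transformers import AutoProcessor

# Passing use_fast=True here picks up PixtralImageProcessorFast automatically.
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)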

One bug I noticed: if I specify use_fast=True and there isn't a fast version of the ImageProcessor available, I get an exception. I can look into this, but it would be good to get clarity on whether this is unintended behavior.

yonigozlan commented 1 day ago

Oh, great news that it already works with AutoProcessor. As I said, this is the first fast image processor used in a Processor, so it was not guaranteed :).

One bug I noticed: if I specify use_fast=True and there isn't a fast version of the ImageProcessor available, I get an exception. I can look into this, but it would be good to get clarity on whether this is unintended behavior.

Yes, this is the same right now when using ImageProcessingAuto. I don't think it should be that way, though, especially as more and more people will want to use fast image processors by default. I'll open a PR to fix this.
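
For concreteness, a sketch of the failure mode in question (the checkpoint name is hypothetical, and this thread doesn't pin down the exception type):

from transformers import AutoImageProcessor

# Hypothetical checkpoint whose image processor has no fast variant yet;
# per this thread, the call below currently raises instead of falling back
# to the slow class.
AutoImageProcessor.from_pretrained("some-org/model-without-fast-processor", use_fast=True)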

Current plan is:

The deprecation cycle is needed because there are slight differences in outputs when using torchvision vs. PIL; see PR https://github.com/huggingface/transformers/pull/34785 for more info.
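
As a quick way to see the slow-vs-fast differences mentioned above (reusing slow_processor, fast_processor, and image from the earlier snippets; the tolerances are illustrative, not values taken from the linked PR):

import torch

slow_out = slow_processor(images=image, return_tensors="pt")["pixel_values"]
fast_out = fast_processor(images=image, return_tensors="pt")["pixel_values"]
# The outputs differ slightly because torchvision and PIL implement
# operations such as resizing with different interpolation details.
torch.testing.assert_close(slow_out, fast_out, atol=1e-1, rtol=1e-1)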