intel / intel-extension-for-openxla

Apache License 2.0

should I use openxla or intel extension for transformers to load Whisper Jax Model? #36

Open sleepingcat4 opened 1 month ago

sleepingcat4 commented 1 month ago

I want to load Whisper JAX on an Intel Data Center GPU Max Series card. Should I use the Intel OpenXLA extension or the Intel Extension for Transformers? I'm not sure whether OpenXLA supports quantisation by default.

Zantares commented 2 weeks ago

I suggest starting with the OpenXLA Extension, because the Transformers Extension has not handled JAX models before. The OpenXLA Extension has basic support for naive quantization based on Keras 3, but it has not been verified in all scenarios.

sleepingcat4 commented 2 weeks ago

@Zantares thanks! At the moment I have opted to use HPU (Gaudi2). Can OpenXLA compile JAX models on Gaudi2?

Can you suggest sample code that loads a JAX model in 4-bit quantisation and runs it?

Zantares commented 2 weeks ago

> @Zantares thanks! At the moment I have opted to use HPU (Gaudi2). Can OpenXLA compile JAX models on Gaudi2?
>
> Can you suggest sample code that loads a JAX model in 4-bit quantisation and runs it?

Gaudi support is still under development because Gaudi uses a different low-level software stack... We have added a simple FP8 example in our repo: https://github.com/intel/intel-extension-for-openxla/tree/main/example/fp8, but we haven't verified INT4 yet. Which INT4 model are you looking for? Maybe we can check it ourselves first.

sleepingcat4 commented 2 weeks ago

@Zantares It's the Whisper JAX model. I'm actually working with Intel (Intel AI Labs), and a former Intel employee on our team suggested that on Intel Gaudi2 the Intel OpenXLA library won't provide an advantage, since Gaudi2 already uses JIT compilation.

What are your thoughts?

Zantares commented 2 weeks ago

I may not be able to offer many suggestions on Gaudi because it's not ready yet... But we have verified that JAX Whisper models (from the Transformers Flax example: https://github.com/huggingface/transformers/tree/main/examples/flax/speech-recognition) can run on Intel GPUs (Data Center Max/Flex).

For GPUs, OpenXLA can provide some generic optimizations and make applications run faster. For Gaudi, as far as I know, it uses a different low-level software stack and may not gain much.
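The "generic optimizations" here come from XLA compilation: `jax.jit` hands a whole function to XLA, which can fuse elementwise ops into fewer kernels on whichever PJRT backend (CPU, Intel GPU via the extension, etc.) is active. A minimal, hardware-agnostic sketch:

```python
import jax
import jax.numpy as jnp

# jax.jit traces this function once and compiles it with XLA;
# the elementwise tanh / multiply / add can be fused into one kernel.
@jax.jit
def fused(x):
    return jnp.sum(jnp.tanh(x) * 2.0 + 1.0)

x = jnp.ones((1024,))
y = fused(x)   # first call compiles; later calls reuse the compiled code
print(y)
```

The same source runs unchanged on any installed backend, which is why the optimization story differs per device: the XLA lowering exists for GPUs, but Gaudi's stack is separate.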

sleepingcat4 commented 1 week ago

@Zantares that's a wonderful insight. If I'm not mistaken, OpenXLA can provide an edge on Intel GPUs, right? And the Gaudi stack that would take advantage of it isn't done yet.

I have a feature request: how about integrating the OpenXLA library into the intel_extension_for_transformers library? Intel Extension for Transformers aims to standardise loading and inference of HF models on Intel hardware, and I have already run HF models in 4-bit with it; it was quite a breeze.

Could you integrate the OpenXLA library components so that we can run JAX models on Intel GPUs, and maybe Gaudi too, using Intel Extension for Transformers (for inference only)?

Loading a whole separate Intel library just for generic optimisations seems like overkill.

Zantares commented 1 week ago

This is more of an intel_extension_for_transformers request than an OpenXLA one. Since intel_extension_for_transformers is an independent, modular third-party project, there isn't much work to do on the OpenXLA side. You could raise it with the intel_extension_for_transformers community if you have a strong need for JAX models. I can't decide it myself: the current main direction is PyTorch, so intel_extension_for_transformers will serve PyTorch first.

Same story with Gaudi: supporting PyTorch is the first priority, then JAX/OpenXLA. You can see that even on the PyTorch side, Gaudi support is still WIP. That's what I mentioned in previous comments: Gaudi is a totally different stack and can't leverage the current work.