kookmin-sw / capstone-2024-08

Loro, a personalized AI speech-practice application for aspiring news announcers
https://kookmin-sw.github.io/capstone-2024-08

Use Whisper-Jax for speed #73

Open why-arong opened 3 months ago

why-arong commented 3 months ago

Installation

Whisper JAX was tested using Python 3.9 and JAX version 0.4.5.

Installation assumes that you already have the latest version of the JAX package installed on your device; if not, follow the official JAX installation guide:

https://github.com/google/jax#installation

Check GPU Hardware:

```
lspci | grep -i nvidia
```

Check System Architecture:

```
uname -m
```

→ We use a g4dn instance (NVIDIA T4 GPU, x86_64), so we need the CUDA build of JAX.

So, run the following command:

```
pip install -U "jax[cuda12]"
```
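A quick sanity check after installation: confirm that JAX can actually see the GPU. The snippet below is a minimal sketch; on a correctly configured g4dn instance `jax.default_backend()` should report `gpu`, while on a CPU-only machine it silently falls back to `cpu`.

```python
import jax

# List the accelerators JAX detected, e.g. [CudaDevice(id=0)] on a g4dn instance,
# or [CpuDevice(id=0)] if the CUDA build is missing or misconfigured.
print(jax.devices())

# "gpu" on success, "cpu" if JAX fell back to the CPU backend.
print(jax.default_backend())
```

If this prints `cpu` on the g4dn instance, the CUDA wheel was not picked up and the `pip install -U "jax[cuda12]"` step should be repeated.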

Once the appropriate version of JAX has been installed, Whisper JAX can be installed through pip:

```
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
```

All Whisper models on the Hugging Face Hub with Flax weights are compatible with Whisper JAX:

https://huggingface.co/openai

We will use the whisper-medium model together with half-precision (bfloat16) computation for speed.
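To illustrate what half precision buys, here is a small sketch comparing the memory footprint of float32 and bfloat16 arrays in JAX (array shapes are arbitrary, chosen only for the demonstration):

```python
import jax.numpy as jnp

# Same array in full precision (float32, 4 bytes/element)
# and half precision (bfloat16, 2 bytes/element).
x32 = jnp.ones((1000, 1000), dtype=jnp.float32)
x16 = jnp.ones((1000, 1000), dtype=jnp.bfloat16)

print(x32.nbytes)  # 4000000 bytes
print(x16.nbytes)  # 2000000 bytes — half the memory traffic
```

Halving the memory per element roughly halves memory bandwidth pressure, and on GPUs with tensor cores (like the T4 in a g4dn instance) half-precision matmuls also run on faster hardware paths.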

Putting it all together

```python
from whisper_jax import FlaxWhisperPipline  # note: "Pipline" is the library's actual spelling
import jax.numpy as jnp

# instantiate pipeline in bfloat16 (half precision)
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.bfloat16)

# transcribe an audio file
text = pipeline("audio.mp3", task="transcribe")
```

Whisper JAX makes use of JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices.

This function is Just In Time (JIT) compiled the first time it is called. Thereafter, the compiled function is cached, so subsequent calls run at full speed.

→ This means the first forward call pays the compilation cost; every call after that is fast.
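The compile-once, run-fast behavior can be seen with a toy `jax.jit`-compiled function (this is an illustrative sketch, not the Whisper pipeline itself, but the pipeline behaves the same way at a much larger scale):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

x = jnp.ones((512, 512))

t0 = time.perf_counter()
matmul(x, x).block_until_ready()   # first call: traces + compiles the function
first = time.perf_counter() - t0

t0 = time.perf_counter()
matmul(x, x).block_until_ready()   # second call: compiled version is cached
second = time.perf_counter() - t0

print(f"first call:  {first:.4f}s (includes compilation)")
print(f"second call: {second:.4f}s (cached)")
```

`block_until_ready()` is needed because JAX dispatches asynchronously; without it the timings would only measure dispatch, not execution.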

Additionally, it supports batching and timestamps, but we won't use those for now.