sanchit-gandhi / whisper-jax

JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
Apache License 2.0

Model is not performing fast on AWS GPU TensorFlow 2.12.0 (Ubuntu 20.04) #57

Open pradipBchoudhari opened 1 year ago

pradipBchoudhari commented 1 year ago

Hi Team, you did a great job and I appreciate your work. Could you help me understand why I am not getting the expected transcription performance? In my case, a 50-second audio file takes 224-260 seconds to transcribe.

System details:
OS: Ubuntu 20.04.6 LTS (GNU/Linux 5.15.0-1031-aws x86_64)
AMI: AWS Deep Learning AMI GPU TensorFlow 2.12.0
NVIDIA driver version: 525.85.12
CUDA version: 11.8
Python: 3.10.10

Setup: I followed the steps below (the two JAX wheels are alternatives; pick the one matching your CUDA toolkit, here CUDA 11.8).

# Upgrade pip
pip install --upgrade pip
# CUDA 12 installation
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install whisper-jax
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
# Install ffmpeg
sudo apt install ffmpeg -y

I also checked the available devices before executing the code.

>>> import jax
>>> jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
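If the CUDA wheel had not installed correctly, JAX would silently fall back to CPU, which on its own could produce timings in this range; the active backend is easy to confirm (standard JAX calls, nothing whisper-jax specific):

import jax

# "gpu" means the CUDA-enabled jaxlib is active; "cpu" means the accelerator
# wheel was not picked up and everything runs on the host.
print(jax.default_backend())

# Raises a RuntimeError if JAX cannot see any GPU device.
print(jax.devices("gpu"))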

Executing the code below:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time

# instantiate pipeline with bfloat16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.bfloat16, batch_size=8)

print("Compilation phase - Transcribe start ...")
t = time.time()
# transcribe and return timestamps
outputs = pipeline("p1.wav", task="transcribe", return_timestamps=True)
print("... Transcribe end")
print(time.time() - t)
print(outputs)

print("Process phase - Transcribe start ...")
t = time.time()
# transcribe and return timestamps
outputs = pipeline("p1.wav", task="transcribe", return_timestamps=True)
print("... Transcribe end")
print(time.time() - t)
print(outputs)
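For context, the two labelled phases in the script reflect JAX's JIT behaviour: the first pipeline call traces the model and compiles it with XLA, and repeat calls with the same input shapes reuse the cached executable. A minimal sketch of that compile-then-cache pattern with a toy jitted function (illustrative only, not the whisper-jax internals):

import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

x = jnp.ones((2048, 2048))

t = time.perf_counter()
matmul(x, x).block_until_ready()  # first call: trace + XLA compile + run
print("compile + run:", time.perf_counter() - t)

t = time.perf_counter()
matmul(x, x).block_until_ready()  # same shapes: cached executable, run only
print("cached run:", time.perf_counter() - t)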

Output (note: the transcribed text has been removed from the snippet):

2023-04-28 13:00:35.470100: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Downloading (…)rocessor_config.json: 100% 185k/185k [00:00<00:00, 180MB/s]
Downloading (…)okenizer_config.json: 100% 843/843 [00:00<00:00, 8.56MB/s]
Downloading (…)olve/main/vocab.json: 100% 1.04M/1.04M [00:00<00:00, 375MB/s]
Downloading (…)/main/tokenizer.json: 100% 2.20M/2.20M [00:00<00:00, 226MB/s]
Downloading (…)olve/main/merges.txt: 100% 494k/494k [00:00<00:00, 108MB/s]
Downloading (…)main/normalizer.json: 100% 52.7k/52.7k [00:00<00:00, 233MB/s]
Downloading (…)in/added_tokens.json: 100% 2.08k/2.08k [00:00<00:00, 22.4MB/s]
Downloading (…)cial_tokens_map.json: 100% 2.08k/2.08k [00:00<00:00, 22.6MB/s]
Downloading (…)lve/main/config.json: 100% 1.99k/1.99k [00:00<00:00, 18.6MB/s]
Downloading flax_model.msgpack: 100% 3.06G/3.06G [00:07<00:00, 434MB/s]
Downloading (…)neration_config.json: 100% 3.48k/3.48k [00:00<00:00, 30.5MB/s]
Compilation phase - Transcribe start ...
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
... Transcribe end
260.81376910209656

Process phase - Transcribe start ...
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
... Transcribe end
224.0671808719635

Thank you, Pradip

pradipBchoudhari commented 1 year ago

After correcting the dtype to jnp.float16 and setting batch_size to 1, I got much better performance on the GPU:

pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16, batch_size=1)
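Since bfloat16 is optimised for TPUs while most GPUs run float16 faster, one way to keep the same script portable is to pick the half-precision dtype from the active backend; a small sketch (assumption: float32 as a CPU fallback):

import jax
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Assumption: bfloat16 on TPU, float16 on GPU, full precision elsewhere.
backend = jax.default_backend()
if backend == "tpu":
    dtype = jnp.bfloat16
elif backend == "gpu":
    dtype = jnp.float16
else:
    dtype = jnp.float32

pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=dtype, batch_size=1)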

Input: 50-second audio file
Output:

2023-05-15 09:43:43.061333: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Compilation phase - Transcribe start ...
2023-05-15 09:44:09.727473: W external/xla/xla/service/gpu/ir_emitter_triton.cc:761] Shared memory size limit exceeded.
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
... Transcribe end
31.795202493667603

Process phase - Transcribe start ...
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
... Transcribe end
4.632764577865601
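With compilation isolated to the first call (31.8 s vs 4.6 s here), a fairer throughput figure discards the warm-up call and averages several cached runs; a minimal sketch reusing the pipeline and p1.wav from above:

import time

# Warm-up call: pays the one-off trace + compile cost, excluded from timing.
pipeline("p1.wav", task="transcribe", return_timestamps=True)

runs = 5
t = time.perf_counter()
for _ in range(runs):
    pipeline("p1.wav", task="transcribe", return_timestamps=True)
print(f"avg cached run: {(time.perf_counter() - t) / runs:.2f} s")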