melihogutcen opened this issue 1 year ago
It took forever to get flax/jax/cuda installed correctly on this ancient hardware. But FWIW it's not faster on this old brick either. It does take most load off CPU, however... Confirming with Quadro M3000M, Compute Capability 5.2
Try the TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' stanza and see if it tells you anything new.
$ TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' python
Python 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from whisper_jax import FlaxWhisperPipline
2023-04-25 03:26:34.596566: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> import jax.numpy as jnp
>>>
>>> pipeline = FlaxWhisperPipline("openai/whisper-small.en", dtype=jnp.bfloat16, batch_size=16)
2023-04-25 03:26:57.709620: I external/xla/xla/service/service.cc:168] XLA service 0xad4fed0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.709677: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 03:26:57.713812: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 03:26:57.779957: I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:997] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-25 03:26:57.780275: I external/xla/xla/service/service.cc:168] XLA service 0xacc8c40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.780332: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Quadro M3000M, Compute Capability 5.2
2023-04-25 03:26:57.781066: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 03:26:57.781166: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 3175464960 bytes on device 0 for BFCAllocator.
2023-04-25 03:26:57.807468: I external/xla/xla/pjrt/pjrt_api.cc:86] GetPjrtApi was found for tpu at /home/k/.local/lib/python3.10/site-packages/libtpu/libtpu.so
2023-04-25 03:26:57.807502: I external/xla/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type tpu
2023-04-25 03:26:59.436941: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8900
2023-04-25 03:26:59.627354: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
@themanyone I tried your script with the whisper-large-v2 model, and the output is as below.
2023-04-25 14:58:25.859863: I external/xla/xla/service/service.cc:168] XLA service 0xa2ff910 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:25.859891: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 14:58:25.868946: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 14:58:26.411156: I external/xla/xla/service/service.cc:168] XLA service 0xa072e20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:26.411209: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3090 Ti, Compute Capability 8.6
2023-04-25 14:58:26.411804: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 14:58:26.411887: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19069648896 bytes on device 0 for BFCAllocator.
2023-04-25 14:58:32.620089: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8500
2023-04-25 14:58:32.688973: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
Perhaps I spoke too soon. I was not able to get a speedup with dtype=jnp.bfloat16, batch_size=16. HOWEVER, when I used the example exactly as given in the README, and without debug logging, it works!!! Decoding is sped up to 1.3 seconds the second time the pipeline is called.
JAX_PLATFORMS='' python
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-small.en")
import time
t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
15.293785810470581
t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
1.2636265754699707
Here, I used the recommended parameters (https://huggingface.co/blog/asr-chunking, https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing#scrollTo=Mh_e6rV62QUM) @themanyone
In these two code snippets the parameters, audio data, and environment are the same. Still, there is no acceleration. @sanchit-gandhi
Hey @melihogutcen - if I understand correctly, the JAX code you're running only does one transcription step (based on what you've shared here: https://github.com/sanchit-gandhi/whisper-jax/issues/44#issue-1682692818)? If this is the case, this first transcription step is our compilation step, which we expect to be slow.
If you do a second transcription step, you'll find that Whisper JAX should be extremely fast:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time
# instantiate pipeline with float16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=8)
# transcribe and return timestamps - compilation step will be slow
start = time.time()
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Compilation: ", runtime)
# transcribe again - use cached function, will be fast
start = time.time()
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Cached: ", runtime)
You can read more about just-in-time (JIT) compilation here: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#jit-compiling-a-function
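For a standalone illustration of the compile-then-cache behaviour (a minimal sketch, not Whisper-specific), the same two-call pattern can be reproduced with plain jax.jit:
# Minimal JIT demo: the first call traces and compiles, the second reuses the cached executable.
import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul(x):
    return x @ x

x = jnp.ones((2048, 2048))

start = time.time()
matmul(x).block_until_ready()  # slow: tracing + XLA compilation
print("first call: ", time.time() - start)

start = time.time()
matmul(x).block_until_ready()  # fast: cached compiled function
print("second call:", time.time() - start)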
Yes, the first transcription is much slower, and the second is much better than the first, but it is still not faster than WhisperTransformers.
For example, the durations are as below:
Duration of the first transcription with WhisperJAX: 60.4s
Duration of the second transcription with WhisperJAX: 20.1s
Duration of the WhisperTransformers: 14.23s
so this implementation only serves use cases where you intend to transcribe the same audio file more than once?
Why would you transcribe the same audio file twice? I can get 1000000x performance by caching the text...
No, you can transcribe any audio file with this method. You have to run it slowly once (compile). After doing this, you can run it fast on any audio file after that (cached). See the JAX JIT docs for details on JIT: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
And the Kaggle Notebook for an application to Whisper JAX: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu
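In other words, the pattern looks like this (a sketch; the file names are placeholders, and only the first call pays the compilation cost as long as the pipeline settings stay the same):
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=8)

# The first file triggers compilation; later files of any length reuse the cached function.
for audio_file in ["first.mp3", "second.mp3", "third.mp3"]:
    start = time.time()
    outputs = pipeline(audio_file, task="transcribe")
    print(audio_file, time.time() - start)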
Is it possible to pre-compile rather than passing through an audio file first?
Yep - you can just pass a dummy log-mel spectrogram, see: https://github.com/sanchit-gandhi/whisper-jax/blob/e2dddc98f8af5e5522341f61d5647cb31cd33ec7/app/app.py#L86-L87
After that you can call:
outputs = pipeline("audio.mp3", return_timestamps=True)
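For reference, the pre-compilation in the linked app.py looks roughly like this (a sketch based on that snippet; the dummy input is a batch of 80 x 3000 log-mel frames, i.e. 30 s of audio, and pipeline.forward is assumed to be the internal call it warms up):
import time
import numpy as np
from whisper_jax import FlaxWhisperPipline

BATCH_SIZE = 8
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=BATCH_SIZE)

# Pre-compile on dummy log-mel spectrograms so the first real file is already fast.
start = time.time()
pipeline.forward(
    {"input_features": np.ones((BATCH_SIZE, 80, 3000))},
    batch_size=BATCH_SIZE,
    return_timestamps=True,
)
print("Compilation:", time.time() - start)

# Real transcriptions now hit the cached executable.
outputs = pipeline("audio.mp3", return_timestamps=True)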
@melihogutcen I have the same issue. I tried the original implementation and this one, and the original seems to be faster each time (I am comparing the second runs). Were you able to get the fast results?
This issue WAS happening, before totally wiping the system and upgrading fresh on a new hard drive. This time I installed updated video drivers, CUDA, and cuDNN from the NVIDIA website instead of the distro-packaged versions. Now it's super fast.
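For anyone else hitting this, a quick way to confirm whether JAX ended up on the GPU or silently fell back to CPU (a minimal check, not specific to whisper-jax):
# If this prints only CPU devices, JAX is not using the CUDA/cuDNN install at all.
import jax
print(jax.devices())          # expect a CUDA/GPU device entry on a working setup
print(jax.default_backend())  # "gpu" expected; "cpu" means a silent fallback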
Hello guys, I'm in the same boat. Here is my code:
import time
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)
t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)
First run: 62 sec; last run: 38 sec.
audio file duration: 00:02:22.87 (from ffmpeg)
output of my nvidia-smi:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3090 On | 00000000:01:00.0 Off | N/A |
| 0% 45C P8 25W / 350W| 19175MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
/usr/local/cuda/bin/nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
I was running Python with TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' and got these messages:
I external/xla/xla/service/service.cc:168] XLA service 0x560a412f7350 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:433] TfrtCpuClient created.
I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I external/xla/xla/service/service.cc:168] XLA service 0x560a3fe5d650 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:521] Using BFC allocator.
I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19065569280 bytes on device 0 for BFCAllocator.
I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:434] Loaded cuDNN version 8700
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 68 bytes spill stores, 68 bytes spill loads
I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 128 bytes spill stores, 128 bytes spill loads
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
I would appreciate any direction on how to find the bottleneck. faster-whisper gets this job done much faster, so I'm guessing something is not okay with my setup.
Thank you.
I could not install torch in the same venv as whisper-jax when making my open-source Whisper Dictation app here. Doing so would downgrade nvidia-cudnn-cu11 to a non-working version that would mostly use the CPU, and then I'd have to run pip install --upgrade nvidia-cudnn-cu11 to get it back. So I put whisper-jax in a venv to keep them separate: torch with CUDA on Python 3.11 runs in my main Python install, and nvidia-cudnn-cu11, jax[cuda], Python 3.10 et al. live in the virtual environment.
Hi, I couldn't get faster results. The Transformers implementation is faster than the JAX implementation.
System info:
jax==0.4.8, jaxlib==0.4.7+cuda11.cudnn82, transformers==4.28.1, CUDA Version: 11.7, Python 3.9.16, GPU: RTX 3090 Ti
Transformers Implementation:
JAX Implementation:
Here I tried 3-4 times, but I couldn't decrease the computation time.