dinskuns opened 1 year ago
I was able to make it work. You don't get a speedup on the first transcription, which is slower because of JIT compilation; you get better performance on subsequent calls to the function (see the sketch at the end of this comment). The difference is more like 10x on my old GPU, a 4 GB Nvidia Quadro 3000M, rather than the advertised figure, which was measured against a TPU.
It also depends on the audio duration, since the gain basically comes from parallelizing transcription of the 30-second chunks, so with an hour-long video on a powerful GPU or TPU you should see it, but... Also be aware that JAX produces inaccurate transcriptions compared to the original OpenAI implementation, with a lot of repeated words and gibberish, probably related to the known hallucination issue.
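For context, here is a minimal sketch of the warm-up effect, assuming the `FlaxWhisperPipline` API from the whisper-jax README; the checkpoint name and `"audio.mp3"` path are placeholders:

```python
import time

from whisper_jax import FlaxWhisperPipline

# Any Whisper checkpoint name should work; tiny keeps the demo light.
pipeline = FlaxWhisperPipline("openai/whisper-tiny")

# The first call pays the JIT compilation cost, so it is slow.
start = time.time()
pipeline("audio.mp3")  # placeholder audio file
print(f"first call (compile + run): {time.time() - start:.1f}s")

# Subsequent calls reuse the compiled function; this is where the speedup shows.
start = time.time()
pipeline("audio.mp3")
print(f"second call (cached compile): {time.time() - start:.1f}s")
```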
> JAX produces inaccurate transcriptions compared to the original OpenAI implementation, with a lot of repeated words and gibberish, probably related to the known hallucination issue.
We don't get that problem at all, even on the tiny model. It does happen with other language models when using an incorrect float or int dtype, specifically when converting the model to a dtype that my aging video card isn't optimized for. Plain float16 or float32 works best on the old hardware.
On another topic: for a somewhat usable server implementation handling multiple clients, we used jnp.bfloat16, which only seems to help with batch_size > 1. batch_size does nothing to speed up single-user performance and only slows things down. Again, this would presumably work better on a TPU; we were just testing. Both knobs are sketched below.
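A sketch of the half-precision and batching settings being discussed, again assuming the `dtype` and `batch_size` constructor arguments from the whisper-jax README; the checkpoint, values, and `"audio.mp3"` path are illustrative:

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# bfloat16 halves memory traffic; in our testing it only paid off with batch_size > 1.
pipeline = FlaxWhisperPipline(
    "openai/whisper-tiny",
    dtype=jnp.bfloat16,  # on older cards, jnp.float16 or jnp.float32 behaved better for us
    batch_size=8,        # batches the 30s chunks; no benefit for a single short clip
)

text = pipeline("audio.mp3")  # placeholder audio file
```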
I was not able to get the 70x speedup at all; quite the opposite, things slow down for 30-second audios.