argmaxinc / WhisperKit

On-device Speech Recognition for Apple Silicon
https://takeargmax.com/blog/whisperkit
MIT License

Added TimestampRulesFilter implementation #45

Closed · jkrukowski closed this 6 months ago

jkrukowski commented 6 months ago

This PR adds an implementation of TimestampRulesFilter, based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441

A couple of questions here, @ZachNagengast:

ZachNagengast commented 6 months ago

@jkrukowski I pushed a small commit to measure the logit filtering time. Here is what I'm getting for tiny, with and without these new timestamp rules, on the jfk.wav file:

    With:    [WhisperKit] - Logit Filtering: 192.41 ms / 28 runs ( 6.87 ms/run) 37.78%
    Without: [WhisperKit] - Logit Filtering:   0.07 ms / 28 runs ( 0.00 ms/run)  0.02%

This is a bit high, and it becomes especially noticeable with the tiny model. Interestingly, only the first and last few tokens are slow (graph by ChatGPT); this is for jfk.wav:

[graph of per-token logit filtering time omitted]

Hopefully this gives you some guidance on where to look for optimizations. The majority of the slowdown is in this block of code:

            // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
            let sampledTokens = tokens[sampleBegin...]
            let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
            let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
            if lastWasTimestamp {
                if penultimateWasTimestamp {
                    // has to be non-timestamp
                    logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
                } else {
                    // cannot be normal text tokens
                    logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
                }
            }
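
For context, a naive fillLastDimension over an MLMultiArray would look something like the sketch below (illustrative, not the exact WhisperKit helper). Every subscript write boxes the value in an NSNumber, which would explain the per-run cost measured above:

    // A minimal sketch of an element-by-element fill, assuming the logits
    // are stored in an MLMultiArray with the vocabulary as the last dimension.
    import CoreML

    extension MLMultiArray {
        func fillLastDimensionNaive(indexes: Range<Int>, with value: Float) {
            for index in indexes {
                // One boxed NSNumber write per vocabulary entry.
                self[index] = NSNumber(value: value)
            }
        }
    }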
jkrukowski commented 6 months ago

@ZachNagengast I've added a more performant version of the fillLastDimension function, and it seems to be doing better. This is what I get for a release build on the jfk.wav file:

    [WhisperKit] ---- Transcription Timings ----
    [WhisperKit] Audio Load:              2.33 ms /      1 runs (    2.33 ms/run)  0.66%
    [WhisperKit] Audio Processing:        0.11 ms /      1 runs (    0.11 ms/run)  0.03%
    [WhisperKit] Mels:                   35.53 ms /      1 runs (   35.53 ms/run) 10.11%
    [WhisperKit] Encoding:               13.39 ms /      1 runs (   13.39 ms/run)  3.81%
    [WhisperKit] Matrices Init:           0.22 ms /      1 runs (    0.22 ms/run)  0.06%
    [WhisperKit] Prefill:                 0.00 ms /      1 runs (    0.00 ms/run)  0.00%
    [WhisperKit] Decoding:              239.40 ms /     28 runs (    8.55 ms/run) 68.15%
    [WhisperKit] Non-inference:          61.25 ms /     28 runs (    2.19 ms/run) 17.43%
    [WhisperKit] - Logit Filtering:       3.24 ms /     28 runs (    0.12 ms/run)  0.92%
    [WhisperKit] - Sampling:             14.17 ms /     28 runs (    0.51 ms/run)  4.03%
    [WhisperKit] - Kv Caching:            2.79 ms /     28 runs (    0.10 ms/run)  0.80%
    [WhisperKit] - Word Timestamps:       0.00 ms /      0 runs (    0.00 ms/run)  0.00%
    [WhisperKit] - Windowing:             0.08 ms /      1 runs (    0.08 ms/run)  0.02%
    [WhisperKit] Fallbacks:               0.00 ms /      0 runs (    0.00 ms/run)  0.00%
    [WhisperKit] Decoding Full Loop:    351.06 ms /     28 runs (   12.54 ms/run) 99.93%
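
One plausible way to get a speedup of this magnitude (a sketch assuming contiguous Float32 storage, not necessarily the exact change in this PR) is to write through the array's raw data pointer instead of boxed subscripts:

    // Sketch of a pointer-based fill; fillLastDimensionFast is a hypothetical
    // name, and the Float32/contiguity assumptions are stated, not verified.
    import CoreML

    extension MLMultiArray {
        func fillLastDimensionFast(indexes: Range<Int>, with value: Float) {
            precondition(dataType == .float32, "sketch assumes Float32 storage")
            let pointer = dataPointer.bindMemory(to: Float.self, capacity: count)
            for index in indexes {
                pointer[index] = value // plain Float store, no NSNumber boxing
            }
        }
    }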
ZachNagengast commented 6 months ago

Much better! This looks in line with what I was seeing for those faster middle tokens previously. Think this is ready to come out of draft now?

jkrukowski commented 6 months ago

Good to hear. Two things are left:

  1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- I'm not sure what value should go there
  2. force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and instead return false gracefully, wdyt?
ZachNagengast commented 6 months ago
  1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- I'm not sure what value should go there

PrefilledIndex is already being passed into this function, but I actually think it should use initialPromptIndex. A good test to add for accuracy on this would be similar to https://github.com/argmaxinc/WhisperKit/blob/e45dc0a056197c4a4ee3dabe9c604f48b150e519/Tests/WhisperKitTests/UnitTests.swift#L314 where you'd create a set of options that change initialPromptIndex and make sure it's working properly.
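
A self-contained sketch of such a test, exercising the pairing rule with plain arrays instead of the real MLMultiArray-based filter (all names below are illustrative, not WhisperKit API); varying sampleBegin stands in for options with different initialPromptIndex values:

    import XCTest

    // Sketch of an accuracy test in the spirit of the linked UnitTests example.
    final class TimestampPairingRuleTests: XCTestCase {
        // <|0.00|> for the multilingual Whisper tokenizer.
        let timeTokenBegin = 50_364

        /// Mirrors the masking decision from the PR: a lone trailing timestamp
        /// means the next token cannot be normal text.
        func nextMustCompleteTimestampPair(tokens: [Int], sampleBegin: Int) -> Bool {
            let sampled = Array(tokens.dropFirst(sampleBegin))
            guard let last = sampled.last else { return false }
            let lastWasTimestamp = last >= timeTokenBegin
            let penultimateWasTimestamp =
                sampled.count < 2 || sampled.dropLast().last! >= timeTokenBegin
            return lastWasTimestamp && !penultimateWasTimestamp
        }

        func testPairingRuleRespectsSampleBegin() {
            // Three prompt tokens (<|startoftranscript|><|en|><|transcribe|>),
            // then one text token and one unpaired timestamp.
            let tokens = [50_258, 50_259, 50_359, 400, 50_400]
            // With sampleBegin = 3 the trailing timestamp is unpaired, so the
            // next token must complete the pair.
            XCTAssertTrue(nextMustCompleteTimestampPair(tokens: tokens, sampleBegin: 3))
            // With sampleBegin = 5 nothing has been sampled yet; no rule fires.
            XCTAssertFalse(nextMustCompleteTimestampPair(tokens: tokens, sampleBegin: 5))
        }
    }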

  2. force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and instead return false gracefully, wdyt?

Besides the verbosity, I think it's OK. If you want to be extra safe, you can wrap that whole part in a do/catch and log an error, similar to the sampling code. I'm not sure of all the scenarios where BNNS will throw, but returning false would just fall back to the default behavior, so no issues there.
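
A minimal sketch of that pattern, with hypothetical helper names standing in for the throwing BNNS-backed computation:

    import Foundation

    // Sketch of the graceful-fallback pattern: catch instead of force
    // unwrapping, log, and return false so decoding falls back to default
    // behavior. The signature and helpers are illustrative, not the exact
    // WhisperKit code.
    func sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(
        logProbs: [Float],
        timeTokenBegin: Int
    ) -> Bool {
        guard logProbs.indices.contains(timeTokenBegin) else { return false }
        do {
            // Stand-in for the throwing BNNS-backed computation.
            let timestampLogProb = try logSumExp(Array(logProbs[timeTokenBegin...]))
            guard let maxTextLogProb = logProbs[..<timeTokenBegin].max() else {
                return false
            }
            return timestampLogProb > maxTextLogProb
        } catch {
            // Log and fall back instead of crashing on a force unwrap.
            print("[WhisperKit] Timestamp filter failed: \(error)")
            return false
        }
    }

    // Hypothetical throwing helper standing in for the BNNS call.
    func logSumExp(_ values: [Float]) throws -> Float {
        guard let maxValue = values.max() else {
            throw NSError(domain: "TimestampFilterSketch", code: 1)
        }
        let sumOfExponentials = values.reduce(Float(0)) { $0 + exp($1 - maxValue) }
        return maxValue + log(sumOfExponentials)
    }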