Closed: neilmehta24 closed this issue.
I will investigate this.
I can see two things that look like they are contributing here:
- If I generate 100 tokens, the Python and Swift versions are nearly the same speed (around 100 tokens per second on my laptop).
- In Swift we are calling the Tokenizer with the entire list of output tokens and printing whatever was appended to the resulting string. This is $O(n^2)$ performance, as we need to scan $n$ tokens $n$ times. The Python version has a `StreamingDetokenizer` that gets $O(n)$ performance because it only decodes the tail end of the output string as it runs. We need a version of this in Swift.
I suspect you were seeing the latter effect (you probably ran this a few times and the first effect was negligible).
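To put numbers on the $O(n^2)$ claim, here is a small Python sketch of the "re-decode everything, print the new suffix" pattern described above. The toy tokenizer (each id decodes to `<id>`) and the helper names are hypothetical; the point is only to count how many tokens get decoded. Streaming $n$ tokens this way decodes $1 + 2 + \dots + n = n(n+1)/2$ tokens in total, versus roughly $n$ for a streaming detokenizer.

```python
from typing import Callable, List


def stream_by_full_redecode(tokens: List[int], decode: Callable[[List[int]], str]) -> List[str]:
    """Re-decode all tokens so far after each new token and emit only the new suffix."""
    pieces: List[str] = []
    printed = 0
    for i in range(1, len(tokens) + 1):
        text = decode(tokens[:i])      # rescans all i tokens generated so far
        pieces.append(text[printed:])  # print only what was appended
        printed = len(text)
    return pieces


def count_decode_work(n: int) -> int:
    """Total number of token decodes when streaming n tokens with full re-decoding."""
    work = 0

    def counting_decode(ids: List[int]) -> str:
        nonlocal work
        work += len(ids)
        return "".join(f"<{i}>" for i in ids)  # toy tokenizer: id -> "<id>"

    stream_by_full_redecode(list(range(n)), counting_decode)
    return work


for n in (100, 1000):
    print(n, count_decode_work(n))
# 100 5050
# 1000 500500
```

At 100 tokens the extra work is modest, which would be consistent with both versions running at similar speed there; it grows quadratically with longer outputs.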
Tasks:

- port `StreamingDetokenizer`
- use `mx.async_eval(y)` to pipeline the generation
- look at `KVCache` from the Python side as well

We can do these in that order, as the first one is probably the biggest benefit.
> port `StreamingDetokenizer`

In Python we have a naive detokenizer that chops the history on every line break to avoid re-decoding the full sequence. That actually gets you pretty far and is quite simple to implement. The full streaming detokenizers add some speed on top of that, but they are more involved to implement since there are a few cases for different models.
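A minimal Python sketch of the line-break idea, assuming only a `decode(list[int]) -> str` function. `LineResetDetokenizer` and its methods are hypothetical names, not the mlx_lm class: it re-decodes just the tokens of the current line, and once the decoded line ends in `"\n"` it commits that text and starts over, so each decode call is bounded by the line length rather than the whole output.

```python
from typing import Callable, List


class LineResetDetokenizer:
    """Sketch: only re-decode the tokens generated since the last line break."""

    def __init__(self, decode: Callable[[List[int]], str]) -> None:
        self.decode = decode
        self.text = ""                # finalized text, up to the last "\n"
        self.tokens: List[int] = []   # tokens of the current, unfinished line
        self.offset = 0               # chars of the current line already emitted

    def add_token(self, token: int) -> str:
        """Append a token and return the newly produced text."""
        self.tokens.append(token)
        line = self.decode(self.tokens)  # cost bounded by the current line
        new_text = line[self.offset:]
        if line.endswith("\n"):
            # Line finished: commit it and start re-decoding from scratch.
            self.text += line
            self.tokens = []
            self.offset = 0
        else:
            self.offset = len(line)
        return new_text

    @property
    def full_text(self) -> str:
        return self.text + self.decode(self.tokens)


vocab = {0: "Hello", 1: " world", 2: "\n", 3: "Bye"}  # toy tokenizer
d = LineResetDetokenizer(lambda ids: "".join(vocab[i] for i in ids))
print([d.add_token(t) for t in [0, 1, 2, 3]])
# ['Hello', ' world', '\n', 'Bye']
```

This sketch ignores partial multi-byte characters and tokenizer-specific space handling, which is presumably where the per-model cases of the full streaming detokenizers come in.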
> use `mx.async_eval(y)` to pipeline the generation

That should be fairly simple to add. It's about a four-line change in Python.
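The idea behind pipelining is to start computing step n+1 before blocking on the result of step n, so CPU-side work (detokenizing, printing) overlaps with the model. This is not MLX code: it is a pure-Python simulation of that scheduling pattern, assuming a hypothetical `step` function that maps the previous token to the next one. A one-worker `ThreadPoolExecutor` stands in for `mx.async_eval`, and `Future.result()` stands in for `y.item()`.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Iterator


def pipelined_generate(step: Callable[[int], int], first: int, n: int) -> Iterator[int]:
    """Yield n tokens, scheduling step k+1 before waiting on step k."""
    # One worker runs tasks in submission order, like a single compute stream.
    with ThreadPoolExecutor(max_workers=1) as pool:
        current: Future = pool.submit(step, first)  # plays the role of async_eval(y)
        for _ in range(n):
            # Queue the next step first; it chains on `current`'s result.
            nxt = pool.submit(lambda prev=current: step(prev.result()))
            # Only now block on the current token (like y.item()) and hand it
            # to the caller, who can detokenize while `nxt` is computing.
            yield current.result()
            current = nxt


print(list(pipelined_generate(lambda t: t + 1, 0, 5)))
# [1, 2, 3, 4, 5]
```

One extra step is scheduled past the last yielded token and then discarded; that is the price of always running one step ahead.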
> look at `KVCache` from the Python side as well

That actually makes a noticeable difference for longer generations.
Another optimization in Python that is really useful for long prompts and generations: https://github.com/ml-explore/mlx-examples/pull/931

There are two things there.

The prompt splitting is an easy win, a no-brainer. It's basically four lines for faster, lower-memory prompt processing:
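```python
import mlx.core as mx

# Feed the prompt `y` through the model in chunks of `prefill_step_size` tokens,
# evaluating the KV cache after each chunk so the work is not one huge graph.
while y.size > prefill_step_size:
    model(y[:prefill_step_size][None], cache=cache)
    mx.eval([c.state for c in cache])
    y = y[prefill_step_size:]
```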
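To make the control flow of that loop concrete, here is a pure-Python version with the model call replaced by a hypothetical `process` callback (standing in for `model(...)` plus `mx.eval(...)`). Because the loop uses a strict `>`, it always leaves between 1 and `prefill_step_size` tokens for the regular generation step, and each model call sees at most `prefill_step_size` new tokens, which is presumably where the memory saving comes from.

```python
from typing import Callable, Sequence


def prefill_in_chunks(
    prompt: Sequence[int],
    prefill_step_size: int,
    process: Callable[[Sequence[int]], None],
) -> Sequence[int]:
    """Feed the prompt to `process` in fixed-size chunks; return the leftover tail."""
    y = prompt
    while len(y) > prefill_step_size:
        process(y[:prefill_step_size])  # stands in for model(...) + mx.eval(...)
        y = y[prefill_step_size:]
    return y  # 1..prefill_step_size tokens, handled by the first generation step


chunk_sizes = []
tail = prefill_in_chunks(list(range(1300)), 512, lambda c: chunk_sizes.append(len(c)))
print(chunk_sizes, len(tail))
# [512, 512] 276
```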
The rotating buffer is more involved but useful for memory-constrained situations (at the cost of accuracy). We can look at adding that after the other items above.
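The rotating buffer isn't spelled out here, so the following is only a sketch of the general idea under my own assumptions: a fixed-capacity cache that pins the first `keep` entries and, once full, overwrites the oldest of the remaining ones. Memory stays bounded at `max_size` entries, and the accuracy cost comes from the context that gets overwritten. `RotatingBuffer`, `max_size` and `keep` are illustrative names, not a specific library's API, and real entries would be key/value tensors rather than ints.

```python
from typing import Generic, List, TypeVar

T = TypeVar("T")


class RotatingBuffer(Generic[T]):
    """Fixed-capacity cache that pins the first `keep` entries and rotates the rest."""

    def __init__(self, max_size: int, keep: int = 0) -> None:
        assert 0 <= keep < max_size
        self.max_size = max_size
        self.keep = keep
        self.slots: List[T] = []
        self.idx = keep  # next slot to overwrite once the buffer is full

    def append(self, item: T) -> None:
        if len(self.slots) < self.max_size:
            self.slots.append(item)
            return
        self.slots[self.idx] = item  # overwrite the oldest non-pinned entry
        self.idx += 1
        if self.idx == self.max_size:
            self.idx = self.keep  # wrap around, skipping the pinned prefix

    def contents(self) -> List[T]:
        """Entries in temporal order: pinned prefix, then oldest to newest."""
        if len(self.slots) < self.max_size:
            return list(self.slots)
        return self.slots[: self.keep] + self.slots[self.idx :] + self.slots[self.keep : self.idx]


buf: RotatingBuffer[int] = RotatingBuffer(max_size=5, keep=1)
for position in range(8):
    buf.append(position)
print(buf.contents())
# [0, 4, 5, 6, 7]
```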
The performance should now be roughly the same as Python, though I found both of them a little noisy to measure. See #109.
The Python `mlx_lm` implementation generates at ~101 tokens per second for `mlx-community/Phi-3-mini-4k-instruct-4bit`, whereas the Swift code here generates at ~60 tokens per second.

Here is my Python implementation:
Here is my Swift command:

Any ideas on how I can achieve similar speed in Swift?