ml-explore / mlx-swift-examples

Examples using MLX Swift
MIT License

Swift generates tokens substantially slower than python for Phi-3 #93

Closed neilmehta24 closed 1 month ago

neilmehta24 commented 3 months ago

The Python mlx_lm implementation generates at ~101 tokens per second for mlx-community/Phi-3-mini-4k-instruct-4bit, whereas the Swift code here generates at ~60 tokens per second.

Here is my Python implementation:

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit")
response = generate(model, tokenizer, max_tokens=100000, prompt="""<|system|>
You are a helpful assistant.<|end|>
<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>""", verbose=True)

Here is my Swift command:

➜  mlx-swift-examples git:(main) ./mlx-run llm-tool eval --model mlx-community/Phi-3-mini-4k-instruct-4bit -m 100000 -p "<|system|>
You are a helpful assistant.<|end|>
<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>"

Any ideas on how I can achieve similar speed in Swift?

davidkoski commented 3 months ago

I will investigate this

davidkoski commented 3 months ago

I can see two things that look like they are contributing here:

In Swift we are calling the Tokenizer with the entire list of output tokens and taking any additions to the resulting string to print out. This is $O(n^2)$ as we need to scan $n$ tokens $n$ times. The Python version has a StreamingDetokenizer that gets $O(n)$ performance because it only generates the tail end of the output string as it runs -- we need a version of this in Swift.

I suspect you were seeing the latter effect (you probably ran this a few times and the first effect was negligible).
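
For a rough sense of the cost of re-decoding the full history (a toy illustration only, not the actual Swift code): decoding every token generated so far on each step touches on the order of $n^2/2$ tokens in total, versus roughly $n$ for a streaming approach that only decodes the new tail.

n = 10_000                                            # number of generated tokens
full_redecode_work = sum(k for k in range(1, n + 1))  # step k re-decodes all k tokens
streaming_work = n                                    # a streaming detokenizer decodes ~1 token per step
print(full_redecode_work, streaming_work)             # 50005000 vs 10000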

davidkoski commented 3 months ago

TASKs:

  1. port StreamingDetokenizer
  2. use mx.async_eval(y) to pipeline the generation
  3. look at KVCache from the python side as well

We can do these in that order as the first one is probably the biggest benefit.

awni commented 3 months ago

port StreamingDetokenizer

In Python we have a naive detokenizer that chops the history on every line break to avoid needing to re-decode the full sequence. That actually gets you pretty far and is quite simple to implement. The full streaming detokenizers add some speed after that, but they are more involved to implement since there are a few cases for different models.
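
A minimal sketch of that "chop on line break" idea (illustrative only: it assumes a Hugging Face-style tokenizer with a decode(list[int]) method, and the real NaiveStreamingDetokenizer in mlx_lm handles more edge cases):

class NaiveStreamingDetokenizerSketch:
    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._finalized = ""        # text that will never need to be re-decoded
        self._segment_tokens = []   # tokens generated since the last flush

    def add_token(self, token):
        self._segment_tokens.append(token)

    @property
    def text(self):
        segment = self._tokenizer.decode(self._segment_tokens)
        if segment.endswith("\n"):
            # A trailing newline is a safe boundary: later tokens will not merge
            # back across it, so finalize this segment and start a fresh one.
            self._finalized += segment
            self._segment_tokens = []
            segment = ""
        return self._finalized + segment

Each call only re-decodes the tokens since the last newline, so the per-token cost stays roughly constant instead of growing with the output length.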

use mx.async_eval(y) to pipeline the generation

That should be fairly simple to add. It's like a four line change in Python
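
Roughly that change, sketched here with illustrative names (model and sample are assumed helpers, and the KV cache argument is omitted for brevity; this is not the exact mlx_lm code): evaluation of the next token is queued with mx.async_eval before the current one is yielded, so the GPU computes step n+1 while the caller detokenizes and prints step n.

import mlx.core as mx

def generate_step(prompt, model, sample):
    # prompt: 1-D array of token ids; `model` maps (1, T) token ids to logits,
    # `sample` picks a token id array from logits.
    def step(tokens):
        logits = model(tokens[None])       # add a batch dimension
        return sample(logits[:, -1, :])    # sample from the last position

    y = step(prompt)
    mx.async_eval(y)                       # start evaluating before the value is needed
    while True:
        next_y = step(y)
        mx.async_eval(next_y)              # queue step n+1 ...
        yield y.item()                     # ... while we block only on step n
        y = next_y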

look at KVCache from the python side as well

That actually makes a noticeable difference for longer generations
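
The gist of that cache, as a simplified sketch (modeled loosely on mlx_lm's KVCache; the real one handles more details): preallocate the key/value buffers in fixed-size steps and write new entries into a slice, rather than concatenating a new array for every generated token.

import mlx.core as mx

class KVCacheSketch:
    step = 256  # grow the buffers in chunks of this many positions

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0  # number of positions currently filled

    def update_and_fetch(self, keys, values):
        # keys/values for the new token(s): (B, n_heads, T, head_dim)
        needed = self.offset + keys.shape[2]
        if self.keys is None or needed > self.keys.shape[2]:
            # Allocate (or grow) the buffers, rounding capacity up to a multiple
            # of `step`, and copy over anything already cached.
            B, H, _, D = keys.shape
            capacity = ((needed + self.step - 1) // self.step) * self.step
            new_keys = mx.zeros((B, H, capacity, D), keys.dtype)
            new_values = mx.zeros((B, H, capacity, D), values.dtype)
            if self.keys is not None:
                new_keys[:, :, : self.offset, :] = self.keys[:, :, : self.offset, :]
                new_values[:, :, : self.offset, :] = self.values[:, :, : self.offset, :]
            self.keys, self.values = new_keys, new_values
        self.keys[:, :, self.offset : needed, :] = keys
        self.values[:, :, self.offset : needed, :] = values
        self.offset = needed
        return self.keys[:, :, : self.offset, :], self.values[:, :, : self.offset, :]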

awni commented 1 month ago

Another optimization in Python which is really useful for long prompts/generations: https://github.com/ml-explore/mlx-examples/pull/931

There are two things there:

  1. Prompt splitting
  2. Rotating buffer for the cache

The prompt splitting is an easy win / no-brainer. Basically four lines for faster / lower-memory prompt processing:

  # Run the prompt through the model in chunks of prefill_step_size tokens,
  # evaluating the cache after each chunk to bound peak memory.
  while y.size > prefill_step_size:
      model(y[:prefill_step_size][None], cache=cache)
      mx.eval([c.state for c in cache])
      y = y[prefill_step_size:]

The rotating buffer is more involved but useful for memory-constrained situations (at the cost of some accuracy). We can look at adding that after the other items above.
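
To give a flavour of the rotating-buffer idea (a deliberately simplified, illustrative sketch; the actual RotatingKVCache in mlx_lm also keeps a prefix of initial tokens and handles positions properly): once the fixed-size buffer is full, each new entry overwrites the oldest slot, so memory stays constant while distant context is dropped.

import mlx.core as mx

class RotatingCacheSketch:
    def __init__(self, max_size):
        self.max_size = max_size
        self.keys = None
        self.values = None
        self.offset = 0  # total tokens seen so far

    def update_and_fetch(self, keys, values):
        # keys/values: (B, n_heads, 1, head_dim) for a single generation step
        if self.keys is None:
            B, H, _, D = keys.shape
            self.keys = mx.zeros((B, H, self.max_size, D), keys.dtype)
            self.values = mx.zeros((B, H, self.max_size, D), values.dtype)
        idx = self.offset % self.max_size            # slot to overwrite once full
        self.keys[:, :, idx : idx + 1, :] = keys
        self.values[:, :, idx : idx + 1, :] = values
        self.offset += 1
        n = min(self.offset, self.max_size)
        return self.keys[:, :, :n, :], self.values[:, :, :n, :]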

davidkoski commented 1 month ago

The performance should be roughly the same as Python now, though I found both of them to be a little noisy in the measurement. See #109