huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Phi-3 implementation seems to be buggy on metal devices #2128

Open jorgeantonio21 opened 2 weeks ago

jorgeantonio21 commented 2 weeks ago

After running the following command multiple times:

cargo run --release --features metal --example phi -- --model 3 --prompt "The best thing about coding in rust is "

I noticed very degraded token generation performance on my MacBook Pro M3. After profiling the issue, I realized that with repeat_penalty = 1.1, roughly 50 seconds (on average) are spent in

candle_transformers::utils::apply_repeat_penalty(
    &logits,
    self.repeat_penalty,
    &tokens[start_at..],
)?

and notably, most of that time goes to the following line:

let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;

I find this odd, as my MacBook Pro M3 is very fast for a Mamba 2.8b model, which is roughly the size of Phi-3 mini. Also, the operation above only allocates a data buffer of roughly 30-40 thousand f32s (well under a megabyte), which is definitely not that large.
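For reference, a minimal sketch of how one might time that line in isolation (assuming logits is the on-device logits tensor and candle aliases candle_core as in the examples):

// Naive wall-clock timing of the dtype cast + device-to-host copy.
let t = std::time::Instant::now();
let logits_vec = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
println!("logits copy: {} values in {:?}", logits_vec.len(), t.elapsed());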

As a follow-up: a few lines above, I see that the forward pass is done through:

Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,

which might have an impact on the tensor's layout and therefore could affect the allocation. I can keep investigating, though.
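If the layout is the issue, one quick experiment (sketch only, not a confirmed fix) would be to force a contiguous copy on-device before the cast and host transfer, e.g. in the Phi3 arm:

// Hypothetical tweak: contiguous() is a no-op if the view is already packed,
// otherwise it materializes a contiguous copy before to_dtype / to_vec1 touch it.
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?.contiguous()?,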

jorgeantonio21 commented 2 weeks ago

PS: I haven't tested this same command on CUDA devices yet.

LaurentMazare commented 2 weeks ago

One thing that might be tricky here is that the Metal API may act in an asynchronous way. This could explain why you end up spending most of the time in the operation that retrieves the logits from the device. One way to check this would be to remove the repeat penalty and verify that the slowness actually disappears (there is also a device.synchronize() that should ensure that all the GPU ops are complete at that point, but I think it might be broken for Metal at the moment).
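Concretely, something along these lines would separate the pending GPU work from the host copy when measuring (assuming device is the candle Device in use, and with the caveat above about synchronize on Metal):

// Drain any queued Metal work first, so the forward pass isn't billed to the copy.
device.synchronize()?;
let t = std::time::Instant::now();
let host_logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
println!("host copy alone: {:?}", t.elapsed());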

jorgeantonio21 commented 2 weeks ago

After inspecting this and removing the penalty, I realized there is also a considerable amount of time spent on

let next_token = self.logits_processor.sample(&logits)?;

Regarding the asynchronicity of the Metal API: why doesn't it cause issues for other small models, like Mamba 2.8b or TinyLlama?

jorgeantonio21 commented 2 weeks ago

PS: after running on an RTX 4090, I find inference on this model particularly fast, roughly 90 tokens/sec.

LaurentMazare commented 2 weeks ago

I just gave it a try on my MacBook M2 Pro (16GB) and it's indeed extremely slow. I think it might be because we run the computations in f32, so with 3.8b parameters this fills the memory: the weights alone take roughly 4 bytes x 3.8B, about 15GB (and f32 computations are also a lot slower than bf16 ones). The model was originally created in bf16, which is what gets used with CUDA, but bf16 isn't available on the Metal backend in candle at the moment; it's being worked on though.

I've just added a new parameter so that you can specify the dtype, e.g. --dtype f16. With this you can try f16, which seems to be a lot faster, but note that f16 has a much narrower range than bf16, so it's quite possible that you would run into NaNs when using it.
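Applied to the command from the original report, that would be something along the lines of:

cargo run --release --features metal --example phi -- --model 3 --dtype f16 --prompt "The best thing about coding in rust is "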