huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Running models with different precisions #2032

Open jorgeantonio21 opened 5 months ago

jorgeantonio21 commented 5 months ago

I am testing different model architectures, and when I load the model weights (e.g. for the Falcon or Mamba architectures) with either bf16 or f16 precision, I usually get this error:

Candle error: 'unexpected dtype, expected: F32, got: BF16'

I am running the candle examples and passing in f16 or bf16 as the precision. Is there a way around this by tweaking the code? Or should I load weights that are already stored in f16/bf16 from some other repo on Hugging Face?

LaurentMazare commented 5 months ago

Is it possible that the error only happens when retrieving the results with to_vec<f32> or an equivalent? With a var store, the conversion should be handled for you. When retrieving results, however, we error out instead of converting the values back, so you have to call to_dtype beforehand. A good way to find where the issue comes from is to set RUST_BACKTRACE=1; you probably also want the release-with-debug profile (or your project's equivalent) so that the line numbers in the backtrace are accurate. (I'm just guessing here; happy to take a more in-depth look if you can provide a simple repro.)
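To make the distinction concrete, here is a toy, stdlib-only Rust model of a tensor whose storage is tagged with a dtype. Every name in it (`ToyTensor`, `Storage`, `to_vec_f32`, ...) is hypothetical and it is not candle's API. Conversion only happens when you ask for it with `to_dtype`; reading bf16 data back as `f32` fails with a message that mimics the one reported in this issue.

```rust
// Hypothetical toy model of a dtype-tagged tensor; NOT candle's API.
// bf16 values are stored as raw u16 bit patterns (the top 16 bits of an f32).

#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    BF16,
}

#[derive(Debug, Clone)]
enum Storage {
    F32(Vec<f32>),
    BF16(Vec<u16>),
}

struct ToyTensor {
    storage: Storage,
}

/// f32 -> bf16 by truncation (keeps sign, exponent and the top 7 mantissa bits).
/// Real libraries usually round to nearest even instead.
fn f32_to_bf16_bits(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// bf16 -> f32 is exact: the low 16 mantissa bits are filled with zeros.
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

impl ToyTensor {
    fn dtype(&self) -> DType {
        match self.storage {
            Storage::F32(_) => DType::F32,
            Storage::BF16(_) => DType::BF16,
        }
    }

    /// Explicit conversion, similar in spirit to calling `to_dtype` on a tensor.
    fn to_dtype(&self, dtype: DType) -> ToyTensor {
        let storage = match (&self.storage, dtype) {
            (Storage::F32(v), DType::BF16) => {
                Storage::BF16(v.iter().map(|&x| f32_to_bf16_bits(x)).collect())
            }
            (Storage::BF16(v), DType::F32) => {
                Storage::F32(v.iter().map(|&b| bf16_bits_to_f32(b)).collect())
            }
            // Already in the requested dtype: just copy.
            (s, _) => s.clone(),
        };
        ToyTensor { storage }
    }

    /// Reading values back does not convert; it errors on a dtype mismatch.
    fn to_vec_f32(&self) -> Result<Vec<f32>, String> {
        match &self.storage {
            Storage::F32(v) => Ok(v.clone()),
            _ => Err(format!("unexpected dtype, expected: F32, got: {:?}", self.dtype())),
        }
    }
}

fn main() {
    // e.g. logits produced by a model whose weights were loaded in bf16
    let logits = ToyTensor { storage: Storage::F32(vec![1.5, -2.0, 0.25]) }.to_dtype(DType::BF16);

    // Reading them straight back as f32 fails...
    println!("{}", logits.to_vec_f32().unwrap_err());
    // ...whereas converting to f32 first works.
    let values = logits.to_dtype(DType::F32).to_vec_f32().unwrap();
    println!("{values:?}");
}
```

In real candle code the equivalent would presumably be calling `to_dtype(DType::F32)` on the result before `to_vec1::<f32>()`. The toy truncates when going from f32 to bf16 to stay short; the three sample values happen to be exactly representable in bf16, so they round-trip unchanged.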

jorgeantonio21 commented 5 months ago

Thank you @LaurentMazare! I believe the issue was that I was not converting the input tensor to a dtype other than f32. I refactored the code from

for &token in tokens.iter() {
    let input = Tensor::new(&[token], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;

    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}

to

for &token in tokens.iter() {
    // cast the input to the model's precision (`self.dtype`, e.g. DType::BF16)
    let input = Tensor::new(&[token], &self.device)?.to_dtype(self.dtype)?;
    let logits = self.model.forward(&input, &mut state)?;

    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}

However, when running the latter code on my MacBook (with the metal feature enabled), I get the following error:

Candle error: Metal contiguous index_select BF16 BF16 not implemented

Is it the case that the current Metal kernels do not support dtypes other than f32?

LaurentMazare commented 5 months ago

Most Metal ops should support f32, f16, and bf16; this one was somehow missing, so I added it in #2035. That said, my MacBook doesn't support bf16, so I wasn't able to really test it, but hopefully it will work for you.

jorgeantonio21 commented 5 months ago

Thanks a lot for the PR! Unfortunately, I have the same issue with other dtypes too, including f16:

Candle error: Metal contiguous index_select F16 F16 not implemented

LaurentMazare commented 5 months ago

This one is different: you can only index into a tensor with an integer tensor, so u32 indices with f16 values (u32 f16) make sense, but f16 indices with f16 values (f16 f16) don't, as an index cannot be a float.
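As a toy illustration of this point (stdlib-only Rust; `index_select` here is a hypothetical helper, not candle's implementation): an embedding lookup gathers rows of a table using integer token ids. The table's element type is generic and can be any float format, but the ids stay `u32`, which is why the token tensor itself should not be cast to f16/bf16; only the floating-point tensors need the reduced precision.

```rust
/// Gathers rows of `table` (one row per token id, as in an embedding matrix).
/// The element type `T` is generic (f32 here, or e.g. raw f16/bf16 bit patterns),
/// but the ids are always integers.
fn index_select<T: Clone>(table: &[Vec<T>], ids: &[u32]) -> Vec<Vec<T>> {
    ids.iter().map(|&id| table[id as usize].clone()).collect()
}

fn main() {
    // Toy embedding table: vocabulary of 3 tokens, embedding size 2.
    let embeddings = vec![vec![0.0f32, 0.1], vec![1.0, 1.1], vec![2.0, 2.1]];
    // Token ids stay as u32; only the embedding values carry the model's precision.
    let token_ids: Vec<u32> = vec![2, 0];
    let rows = index_select(&embeddings, &token_ids);
    println!("{rows:?}"); // [[2.0, 2.1], [0.0, 0.1]]
}
```

Rust's type system makes a float index impossible here by construction; a dtype-tagged tensor library has to check the same constraint at runtime, hence the "not implemented" error for an F16 index tensor.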

jorgeantonio21 commented 5 months ago

I see, right. It seems, though, that many of these models do not support f16 or bf16. Now that I no longer (erroneously) convert the indices to f16, I get this error:

dtype mismatch in mul, lhs: F16, rhs: F32.

I am running these experiments on Mamba and Falcon, and from the implementation it seems these models do not support dtypes other than f32 (the Mamba state is hardcoded to f32, and in Falcon the mask is also hardcoded to f32).

I wonder whether it would be possible to support other precisions for these models (including f16 and bf16)?
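The failure mode described above can be sketched with a toy, stdlib-only Rust model. `ToyTensor`, `init_state_hardcoded` and `init_state` are hypothetical names, and the toy only tags values with a dtype while keeping them as f32 internally. If a model creates its state or mask with a hardcoded F32 dtype, multiplying it with f16 activations fails; creating it in the dtype the weights were loaded in does not. In candle itself, the analogous change would presumably be to create these tensors with the weights' dtype rather than a hardcoded `DType::F32`.

```rust
// Hypothetical toy: a dtype tag plus values (kept as f32 purely for illustration).
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F16,
    BF16,
    F32,
}

#[derive(Debug)]
struct ToyTensor {
    dtype: DType,
    data: Vec<f32>,
}

impl ToyTensor {
    fn zeros(len: usize, dtype: DType) -> Self {
        ToyTensor { dtype, data: vec![0.0; len] }
    }

    /// Element-wise multiply that, like a strict tensor library, refuses mixed dtypes.
    fn mul(&self, rhs: &ToyTensor) -> Result<ToyTensor, String> {
        if self.dtype != rhs.dtype {
            return Err(format!(
                "dtype mismatch in mul, lhs: {:?}, rhs: {:?}",
                self.dtype, rhs.dtype
            ));
        }
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a * b).collect();
        Ok(ToyTensor { dtype: self.dtype, data })
    }
}

/// Hardcoded variant: the state is always created as F32.
fn init_state_hardcoded(len: usize) -> ToyTensor {
    ToyTensor::zeros(len, DType::F32)
}

/// Variant that follows the dtype the model weights were loaded in.
fn init_state(len: usize, dtype: DType) -> ToyTensor {
    ToyTensor::zeros(len, dtype)
}

fn main() {
    let model_dtype = DType::F16;
    let activations = ToyTensor { dtype: model_dtype, data: vec![1.0, 2.0] };

    // Mixing f16 activations with an f32 state fails.
    let err = activations.mul(&init_state_hardcoded(2)).unwrap_err();
    println!("{err}");

    // Creating the state in the model's dtype avoids the mismatch.
    let out = activations.mul(&init_state(2, model_dtype)).unwrap();
    println!("{:?}", out.dtype);
}
```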

LaurentMazare commented 5 months ago

Yeah, there is no real limitation here; I've made #2036 for Mamba. It works with bf16 but not with f16 (which is somewhat expected: models trained in f32 or bf16 are likely to break in f16). On my RTX 4080, speed increases slightly, from 320 tokens/s to 360 tokens/s, so I wouldn't consider it a big improvement.
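A likely reason f16 is more fragile than bf16 for such models is dynamic range: bf16 keeps f32's 8 exponent bits and gives up mantissa bits, whereas f16 has only 5 exponent bits, so its largest finite value is 65504. The stdlib-only Rust sketch below computes these limits from the bit layouts (`max_finite` is a hypothetical helper). Activations that fit comfortably in f32 or bf16 can overflow to infinity in f16. This is a general explanation, not a diagnosis of the Mamba implementation.

```rust
/// Largest finite value of an IEEE-754-style binary format with the given
/// number of exponent bits and explicit (stored) mantissa bits:
/// (2 - 2^-mantissa_bits) * 2^bias, where bias = 2^(exponent_bits - 1) - 1.
fn max_finite(exponent_bits: i32, mantissa_bits: i32) -> f64 {
    let bias = (1i32 << (exponent_bits - 1)) - 1;
    (2.0 - 2f64.powi(-mantissa_bits)) * 2f64.powi(bias)
}

fn main() {
    let f16_max = max_finite(5, 10); // 65504
    let bf16_max = max_finite(8, 7); // about 3.39e38
    let f32_max = max_finite(8, 23); // equals f32::MAX
    println!("f16: {f16_max}, bf16: {bf16_max:e}, f32: {f32_max:e}");

    // A hypothetical large activation value: representable in bf16, but it
    // would overflow to +inf if stored in f16.
    let activation = 1.0e5;
    println!(
        "fits in f16: {}, fits in bf16: {}",
        activation <= f16_max,
        activation <= bf16_max
    );
}
```

The flip side is precision: bf16 stores only 7 mantissa bits versus f16's 10, which is why bf16 is usually the safer reduced precision for models trained in f32 or bf16.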

jorgeantonio21 commented 5 months ago

This is interesting: on my MacBook Pro it works with f16, but not with bf16. Thanks for the PR @LaurentMazare; it would be great to have this for the Llama and Falcon models, too.