huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Mamba model is broken with `f16` precision #2280

Open · jorgeantonio21 opened 1 week ago

jorgeantonio21 commented 1 week ago

Running the command:

cargo run --example mamba --release --features metal -- --prompt "Tell me a joke please" --dtype f16

does not work. The problem appears to be in the following code from the example:

for &t in tokens.iter() {
    // Feed each prompt token through the model to build up the recurrent state.
    let input = Tensor::new(&[t], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    // Stream the decoded prompt tokens back to stdout as they are processed.
    if let Some(t) = self.tokenizer.next_token(t)? {
        print!("{t}")
    }
}

where logits comes back as a Tensor filled with NaN values.
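
For reference, here is a minimal sketch of how one might confirm the NaNs on the host side; contains_nan is a hypothetical helper, not part of the example:

use candle_core::{DType, Result, Tensor};

/// Hypothetical helper: returns true if any element of `t` is NaN.
fn contains_nan(t: &Tensor) -> Result<bool> {
    // Upcast to f32 first so the host-side check works for any dtype,
    // then pull the flattened values back to the CPU.
    let values = t.to_dtype(DType::F32)?.flatten_all()?.to_vec1::<f32>()?;
    Ok(values.iter().any(|v| v.is_nan()))
}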

LaurentMazare commented 1 week ago

f16 has a far smaller range than f32 so it's quite common for models trained in f32 or bf16 to return some nans if you try to evaluate them in f16. Maybe you could try with the python version and see if it's the same? Alternatively you should be able to run it in bf16 though this will only work on cuda at the moment.