huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Error: Metal error Error while loading function: "Function 'cast_bf16_f16' does not exist" with llama3 #2163

Open yIllusionSky opened 5 months ago

yIllusionSky commented 5 months ago

I modified the llama3 example code so that it runs locally, but I get an error when executing it. I am using a MacBook with an M3 chip. Below is the Rust code (I have omitted some irrelevant parts):

mod token; // local TokenOutputStream helper (omitted)

use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig};
use std::io::Write;

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

fn main() -> anyhow::Result<()> {
    let device = Device::new_metal(0)?;
    // The Llama 3 shards are likely stored in bf16, so requesting F16 here makes the
    // loader cast bf16 -> f16 on the Metal device ('cast_bf16_f16').
    let dtype = DType::F16;
    let (llama, tokenizer_filename, mut cache, config) = {
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read("model/config.json")?)?;
        let filenames = vec![
            "model/model-00001-of-00005.safetensors",
            "model/model-00002-of-00005.safetensors",
            "model/model-00003-of-00005.safetensors",
            "model/model-00004-of-00005.safetensors",
            "model/model-00005-of-00005.safetensors",
        ];
        // use_flash_attn = false: flash-attn is a CUDA-only feature.
        let config = config.into_config(false);
        // use_kv_cache = false: every step re-runs the whole context.
        let cache = model::Cache::new(false, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (
            Llama::load(vb, &config)?,
            "model/tokenizer.json",
            cache,
            config,
        )
    };
    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = config
        .eos_token_id
        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
    let prompt = DEFAULT_PROMPT;
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = token::TokenOutputStream::new(tokenizer);

    println!("starting the inference loop");
    print!("{prompt}");
    let mut logits_processor = {
        let temperature: f64 = 0.8;
        let top_k: Option<usize> = None;
        let top_p: Option<f64> = None;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (top_k, top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(299792458, sampling)
    };

    let repeat_penalty: f32 = 1.1;
    let repeat_last_n: usize = 128;
    let mut start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    for index in 0..10000 {
        // With the KV cache, only the newest token is fed after the first step.
        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (tokens.len(), 0)
        };
        if index == 1 {
            start_gen = std::time::Instant::now()
        }
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, context_index, &mut cache)?;
        let logits = logits.squeeze(0)?;
        let logits = if repeat_penalty == 1. {
            logits
        } else {
            let start_at = tokens.len().saturating_sub(repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(&logits, repeat_penalty, &tokens[start_at..])?
        };
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        token_generated += 1;
        tokens.push(next_token);

        if Some(next_token) == eos_token_id {
            break;
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest()? {
        print!("{rest}");
    }
    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
    Ok(())
}

When I run cargo run, I get the following error:

Error: Metal error Error while loading function: "Function 'cast_bf16_f16' does not exist"

Caused by:
    Error while loading function: "Function 'cast_bf16_f16' does not exist"
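The kernel name `cast_bf16_f16` suggests the failure happens while converting weights stored as bf16 into the requested `DType::F16`. As far as I know, the Llama 3 checkpoints on the Hugging Face Hub are distributed in bf16, and the dtype passed to `VarBuilder::from_mmaped_safetensors` is presumably applied by casting each tensor on the target device. The std-only sketch below reads a `.safetensors` header (an 8-byte little-endian length followed by a JSON header, per the safetensors format) to list the stored dtypes, then predicts the cast kernel name, assuming the `cast_<src>_<dst>` pattern seen in the error. `read_header`, `stored_dtypes` and `cast_kernel_name` are hypothetical helpers, and the dtype scan is a naive string search rather than a JSON parser.

```rust
use std::collections::BTreeSet;
use std::fs::File;
use std::io::{self, Read};

/// Reads the JSON header of a `.safetensors` stream:
/// an 8-byte little-endian u64 length, then that many bytes of UTF-8 JSON.
fn read_header<R: Read>(mut r: R) -> io::Result<String> {
    let mut len_bytes = [0u8; 8];
    r.read_exact(&mut len_bytes)?;
    let len = u64::from_le_bytes(len_bytes) as usize;
    // Arbitrary sanity cap so a corrupt file does not trigger a huge allocation.
    if len > 100_000_000 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large"));
    }
    let mut buf = vec![0u8; len];
    r.read_exact(&mut buf)?;
    String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Collects the distinct values of every `"dtype"` field (naive scan, no JSON parser).
fn stored_dtypes(header: &str) -> BTreeSet<String> {
    const KEY: &str = "\"dtype\"";
    let mut found = BTreeSet::new();
    let mut rest = header;
    while let Some(i) = rest.find(KEY) {
        rest = &rest[i + KEY.len()..];
        // Expect `:`, optional whitespace, then a quoted value such as "BF16".
        let value = rest
            .trim_start()
            .strip_prefix(':')
            .map(str::trim_start)
            .and_then(|v| v.strip_prefix('"'))
            .and_then(|v| v.find('"').map(|end| &v[..end]));
        if let Some(v) = value {
            found.insert(v.to_string());
        }
    }
    found
}

/// Assumed naming pattern, inferred from the error messages: cast_<src>_<dst>.
fn cast_kernel_name(stored: &str, target: &str) -> String {
    format!("cast_{}_{}", stored.to_lowercase(), target.to_lowercase())
}

fn main() -> io::Result<()> {
    // First shard from the snippet above; adjust to your local path.
    let path = "model/model-00001-of-00005.safetensors";
    let file = match File::open(path) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("could not open {path}: {e}");
            return Ok(());
        }
    };
    let header = read_header(file)?;
    let target = "F16"; // the DType::F16 requested in the snippet above
    for stored in stored_dtypes(&header) {
        if stored != target {
            println!("{stored} -> {target} expects kernel {}", cast_kernel_name(&stored, target));
        }
    }
    Ok(())
}
```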
LaurentMazare commented 5 months ago

I had a similar issue in the past where all the bf16 kernels were missing. This was due to an outdated gcc/g++ toolchain (version 12.0); using the latest version that comes with Xcode (15.0) fixed it for me, so maybe you could try this. FWIW, here is a much shorter example that would have triggered the issue:

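use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_metal(0)?;
    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &device)?;
    // f32 -> bf16 on the Metal device (requires a bf16 cast kernel).
    let a = a.to_dtype(DType::BF16)?;
    // bf16 -> f16: the same 'cast_bf16_f16' kernel the llama3 run asked for.
    let a = a.to_dtype(DType::F16)?;
    println!("{a}");
    Ok(())
}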
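The first step of this repro asks the Metal backend for an f32 to bf16 cast (`cast_f32_bf16`, following the naming in the error messages). For reference, here is a CPU-only sketch of what that conversion computes: bf16 keeps the upper 16 bits of an IEEE-754 f32 (same 8-bit exponent, 7 mantissa bits). The rounding mode (round to nearest, ties to even) is an assumption on my part and is not taken from candle's Metal kernels; `f32_to_bf16` and `bf16_to_f32` are hypothetical helpers. All six values in the repro are exactly representable in bf16, so they should survive the round trip unchanged.

```rust
/// f32 -> bf16 bits: keep the top 16 bits, rounding to nearest, ties to even.
fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep the sign and force a quiet-NaN mantissa bit so it cannot collapse to infinity.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, so exact ties round toward an even result.
    let lsb = (bits >> 16) & 1;
    ((bits + 0x7FFF + lsb) >> 16) as u16
}

/// bf16 bits -> f32 is exact: move the 16 bits back into the high half.
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}

fn main() {
    // Same values as the Metal repro above, converted on the CPU instead.
    let a = [[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]];
    for row in a {
        let bf: Vec<u16> = row.iter().map(|&x| f32_to_bf16(x)).collect();
        let back: Vec<f32> = bf.iter().map(|&h| bf16_to_f32(h)).collect();
        println!("{row:?} -> bf16 bits {bf:x?} -> {back:?}");
    }
}
```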
yIllusionSky commented 5 months ago

My Xcode version is 15.3, and gcc -v reports version 15.0. I tried the code you provided and still get the error: 'Error: Metal error Error while loading function: "Function 'cast_f32_bf16' does not exist".' I have also tried pulling the dependencies from git and pinning version 0.5.0, but neither worked.

zackangelo commented 2 months ago

@LaurentMazare would it make sense to compile the kernels into a metallib file in a Cargo build script? I'm running into an unrelated issue where a to_vec1 call after the first forward pass is very slow, and it has been suggested that this is due to the kernel compilation that has to happen at that point.
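A minimal build.rs sketch of that idea, assuming the kernels live in standalone `.metal` files (the paths `src/cast.metal` and `src/unary.metal` are hypothetical) and that Xcode's command-line tools are installed. It uses the standard two-step `xcrun -sdk macosx metal -c` (source to AIR) and `xcrun -sdk macosx metallib` (AIR to library) commands. It is not how candle builds its kernels today, and loading the resulting `.metallib` at runtime is not shown. The argument builders are split out so they can be checked without a Mac.

```rust
use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;

/// Arguments for `xcrun -sdk macosx metal -c <src> -o <air>`.
fn metal_compile_args(src: &Path, air: &Path) -> Vec<String> {
    vec![
        "-sdk".to_string(),
        "macosx".to_string(),
        "metal".to_string(),
        "-c".to_string(),
        src.display().to_string(),
        "-o".to_string(),
        air.display().to_string(),
    ]
}

/// Arguments for `xcrun -sdk macosx metallib <air...> -o <lib>`.
fn metallib_link_args(airs: &[PathBuf], lib: &Path) -> Vec<String> {
    let mut args: Vec<String> = vec!["-sdk".to_string(), "macosx".to_string(), "metallib".to_string()];
    args.extend(airs.iter().map(|p| p.display().to_string()));
    args.push("-o".to_string());
    args.push(lib.display().to_string());
    args
}

fn run_xcrun(args: &[String]) {
    let status = Command::new("xcrun").args(args).status().expect("failed to spawn xcrun");
    assert!(status.success(), "xcrun {:?} failed", args);
}

fn main() {
    // Hypothetical kernel sources; a real crate would list its own .metal files.
    let sources = ["src/cast.metal", "src/unary.metal"];
    for s in &sources {
        println!("cargo:rerun-if-changed={s}");
    }
    // xcrun only exists on a macOS host; skip elsewhere.
    if !cfg!(target_os = "macos") {
        return;
    }
    // OUT_DIR is set by cargo when running a build script; skip if run by hand.
    let out_dir = match env::var_os("OUT_DIR") {
        Some(d) => PathBuf::from(d),
        None => return,
    };
    let mut airs = Vec::new();
    for s in &sources {
        let src = Path::new(s);
        let air = out_dir.join(src.with_extension("air").file_name().unwrap());
        run_xcrun(&metal_compile_args(src, &air));
        airs.push(air);
    }
    // The crate could then embed this file, e.g. with include_bytes! on OUT_DIR.
    run_xcrun(&metallib_link_args(&airs, &out_dir.join("kernels.metallib")));
}
```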