huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0
15.82k stars 953 forks source link

Incompatible shapes when #2501

Open segeljakt opened 1 month ago

segeljakt commented 1 month ago

I tried to modify the code in https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs to become a chatbot where each new prompt considers the history of all previous prompts. This is my code:

use anyhow::{Error, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use candle_transformers::models::llama::Llama;
use candle_transformers::models::llama::LlamaConfig;
use candle_transformers::models::llama::LlamaEosToks;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

// Token text looked up in the tokenizer to find the end-of-sequence id when
// the model config does not supply one (see `Chat::new`).
const EOS_TOKEN: &str = "</s>";
// Penalty factor handed to `apply_repeat_penalty` in `Chat::run`.
const REPEAT_PENALTY: f32 = 1.1;
// Number of most recent tokens passed to the repeat penalty.
const REPEAT_LAST_N: usize = 128;
// Seed given to the `LogitsProcessor`.
const SEED: u64 = 299792458;
// Upper bound on the number of tokens sampled per `Chat::run` call.
const SAMPLE_LEN: usize = 10000;
// Flag passed to `Tokenizer::encode` for every prompt.
const ADD_SPECIAL_TOKENS: bool = true;
// Flag passed to `Tokenizer::decode` when producing the returned text.
const SKIP_SPECIAL_TOKENS: bool = true;
// Passed to `Cache::new` as its first argument.
const USE_KV_CACHE: bool = true;
// Passed to `LlamaConfig::into_config`.
const USE_FLASH_ATTENTION: bool = false;

/// State of a multi-turn chat session: the loaded model plus everything
/// `Chat::run` carries over from one prompt to the next.
pub struct Chat {
    /// Model loaded from the safetensors weights in `Chat::new`.
    model: Llama,
    /// Sampler used to pick the next token from the logits
    /// (constructed with `Sampling::ArgMax`).
    logits_processor: LogitsProcessor,
    /// Cache handed to every `Llama::forward` call.
    cache: Cache,
    /// Tokenizer used to encode prompts and decode the token history.
    tokenizer: Tokenizer,
    /// Device on which the model is loaded and input tensors are created.
    device: Device,
    /// End-of-sequence id(s): taken from the model config, otherwise looked
    /// up from `EOS_TOKEN` in the tokenizer; `None` if neither is available.
    eos_token_id: Option<LlamaEosToks>,
    /// Every token seen so far: all encoded prompts and all sampled tokens.
    /// `run` only ever appends to it.
    tokens: Vec<u32>,
    /// Number of entries of `tokens` that have already been passed to
    /// `Llama::forward`; `tokens[index..]` is what the next call is fed.
    index: usize,
}

impl Chat {
    pub fn new() -> Result<Self> {
        let device = Device::new_metal(0)?;
        let dtype = DType::F16;
        let api = Api::new()?;
        let api = api.repo(Repo::with_revision(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(USE_FLASH_ATTENTION);
        let filenames = vec![api.get("model.safetensors")?];
        let cache = Cache::new(USE_KV_CACHE, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Llama::load(vb, &config)?;

        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?;
        let eos_token_id = config
            .eos_token_id
            .or_else(|| tokenizer.token_to_id(EOS_TOKEN).map(LlamaEosToks::Single));
        let logits_processor = LogitsProcessor::from_sampling(SEED, Sampling::ArgMax);

        Ok(Self {
            model,
            tokenizer,
            logits_processor,
            eos_token_id,
            cache,
            device,
            tokens: Vec::new(),
            index: 0,
        })
    }

    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );

        for _ in 0..SAMPLE_LEN {
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            self.index += tokens_slice.len();

            let next_token = self.logits_processor.sample(&logits)?;
            self.tokens.push(next_token);

            if self.is_eos_token(next_token) {
                break;
            }
        }
        let output = self
            .tokenizer
            .decode(&self.tokens, SKIP_SPECIAL_TOKENS)
            .map_err(Error::msg)?;
        Ok(output)
    }

    fn is_eos_token(&self, token: u32) -> bool {
        matches!(self.eos_token_id, Some(LlamaEosToks::Single(id)) if token == id)
            || matches!(self.eos_token_id, Some(LlamaEosToks::Multiple(ref ids)) if ids.contains(&token))
    }
}

/// Runs two consecutive prompts through one `Chat` session and prints the
/// text returned by each call. Any error aborts the program via `unwrap`.
fn main() {
    let mut chat = Chat::new().unwrap();
    // Same session for both prompts, so the second call sees the first one's history.
    for prompt in ["Hello my name is", "Today"] {
        let reply = chat.run(prompt).unwrap();
        println!("{}", reply);
    }
}

When I run, I get this error:

called `Result::unwrap()` on an `Err` value: BroadcastIncompatibleShapes { src_shape: [3, 3], dst_shape: [1, 32, 3, 349] }

The error happens the second time I call Chat::run in main and is thrown from this statement:

    // ...
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
    // ...

The first time I run the chat in main, the shape of input is [1,5]. After producing an output token, the next shape of input is [1,1] since I use key-value caching.

When I later enter a new prompt and run the chat, the input shape is [1,3] (which includes the EOS token from the previous run). The error disappears if I drop some tokens so the shape becomes [1,1]. Is there something that says the shape must be [1,1] when we use key-value caching?

segeljakt commented 1 month ago

Oh, I managed to get it working by turning off kv-caching and then turning it on again:

// Workaround variant of `Chat::run`: the KV cache is switched off for the
// first forward pass of each call and switched back on afterwards.
// `// ...` marks code elided by the author.
impl Chat {
    // ...
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        // Disable the cache so the first forward pass below runs without it.
        // NOTE(review): presumably with the cache off this pass neither reads
        // nor updates the cached keys/values, so the new prompt tokens would
        // be evaluated without the earlier conversation and would never enter
        // the cache. Confirm against candle's llama implementation.
        self.cache.use_kv_cache = false; // <---- Here

        for _ in 0..SAMPLE_LEN {
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            // Re-enabled on every iteration, so only the first forward pass
            // of each `run` call sees the cache disabled.
            self.cache.use_kv_cache = true;  // <---- Here
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            // ...
        }
        // ...   
    }
    // ...
}

This means the first forward of every run is done without kv-caching. Is this the correct way to approach it?