rustformers / llm

[Unmaintained, see README] An ecosystem of Rust libraries for working with large language models
https://docs.rs/llm/latest/llm/

NaN logits on LLaMA 65B when using 2k+ token contexts #418

Open hugoabonizio opened 10 months ago

hugoabonizio commented 10 months ago

I'm trying to run inference with contexts longer than 2k tokens, but I'm having trouble getting it to work with 65B models. The following code works on 7B-scale models, but fails with "token sampling failed" (caused by NaN logits) when using 65B models.

I'm actually working with internal models, but these 7B and 65B models reproduce the issue as well.

use std::path::PathBuf;
use llm::Model;

fn main() {
    let llama = llm::load::<llm::models::Llama>(
        std::path::Path::new("/data/tmp/llama-65b.ggmlv3.q4_0.bin"),
        // std::path::Path::new("/data/tmp/llama-7b.ggmlv3.q4_0.bin"),
        llm::TokenizerSource::HuggingFaceTokenizerFile(PathBuf::from("/data/tmp/tokenizer.json")),
        llm::ModelParameters {
            use_gpu: true,
            gpu_layers: Some(99), // offload all layers to the GPU
            context_size: 8192,
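            // Note: frequency_scale 0.25 below is presumably 2048 / 8192, i.e.
            // linear RoPE scaling from LLaMA's native 2k context to the 8k
            // context_size requested above.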
            rope_overrides: Some(llm::RoPEOverrides {
                frequency_scale: 0.25,
                ..Default::default()
            }),
            ..Default::default()
        },
        llm::load_progress_callback_stdout
    )
    .unwrap_or_else(|err| panic!("Failed to load model: {err}"));

    println!("\n\ncontext_size {}", llama.context_size());

    let prompt = "hello ".repeat(2800); // prompts up to ~2k tokens work; beyond that, 65B yields NaN logits

    let mut session = llama.start_session(llm::InferenceSessionConfig {
        n_batch: 256, // feed the prompt in batches of 256 tokens
        ..Default::default()
    });

    let res = session.infer::<std::convert::Infallible>(
        &llama,
        &mut rand::thread_rng(),
        &llm::InferenceRequest {
            prompt: (&prompt).into(),
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: Some(1), // a single generated token is enough to trigger the failure
        },
        &mut Default::default(),
        // Discard streamed tokens; we only care whether sampling succeeds.
        |_| Ok(llm::InferenceFeedback::Continue)
    );

    match res {
        Ok(result) => println!("\n\nInference stats:\n{result}"),
        Err(err) => println!("\n{err}"),
    }
}
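
In case it helps with triage, here is a sketch of how I've been trying to look at the logits themselves rather than just the sampling error. It assumes llm::OutputRequest (the type behind the &mut Default::default() argument above) still exposes an all_logits: Option<Vec<f32>> field; if that has changed, treat this as pseudocode. It replaces the infer call in the repro above:

    // Ask infer() to hand back the raw logits of the last evaluation so they
    // can be inspected, instead of only seeing `token sampling failed`.
    // Assumption: llm::OutputRequest has an `all_logits: Option<Vec<f32>>` field.
    let mut output = llm::OutputRequest {
        all_logits: Some(Vec::new()),
        ..Default::default()
    };

    let res = session.infer::<std::convert::Infallible>(
        &llama,
        &mut rand::thread_rng(),
        &llm::InferenceRequest {
            prompt: (&prompt).into(),
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: Some(1),
        },
        &mut output,
        |_| Ok(llm::InferenceFeedback::Continue)
    );

    if let Err(err) = res {
        println!("\n{err}");
    }

    // Count non-finite values in the returned logits.
    if let Some(logits) = output.all_logits.as_ref() {
        let non_finite = logits.iter().filter(|l| !l.is_finite()).count();
        println!("non-finite logits: {non_finite} / {}", logits.len());
    }

With the 65B model and the 2800-token prompt above, this is where I would expect the NaNs to show up.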