huggingface / candle

Minimalist ML framework for Rust

Error: DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory") with multiple GPU #2046

Open evgenyigumnov opened 5 months ago

evgenyigumnov commented 5 months ago

I have 4x RTX 3080 = 40 GB total memory (10 GB per GPU).

I am trying to load the Mistral 7B model, which is about a 15 GB file.

But I get this error:

root@C.10529376:~/ai-server$ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.20s
     Running `target/debug/ai-server`
retrieved the files in 27.070873ms
Error: DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")

Is it possible to run it in multi-GPU mode?
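
(For scale: Mistral 7B has roughly 7.2 billion parameters, so the weights alone come to about 14-15 GB in BF16/F16, which matches the ~15 GB of safetensors files and cannot fit on any single 10 GB RTX 3080; the standard example tries to load the whole model onto one device, hence the OOM.)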

tomsanbear commented 5 months ago

Definitely possible but won't just work with standard examples. Please take a look at this example for how to run across multiple GPUs: https://github.com/huggingface/candle/tree/main/candle-examples/examples/llama_multiprocess
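
For orientation, here is roughly the per-rank shape of that example, written as a sketch rather than working code: it reuses only calls that also appear later in this thread (CudaDevice::new, Device::new_cuda, Comm::from_rank, ShardedSafeTensors::var_builder), and load_shard / ShardedMistral::load are made-up names, since the stock mistral::Model does not take a Comm.

use std::rc::Rc;

use candle_core::{DType, Device};
use cudarc::nccl::safe::{Comm, Id};

// One of these runs in each spawned process, with rank in 0..num_shards.
fn load_shard(
    rank: usize,
    num_shards: usize,
    id: Id, // NCCL unique id: created on rank 0, shared with the other ranks
    filenames: &[std::path::PathBuf],
) -> anyhow::Result<()> {
    // One CUDA device per rank.
    let cuda_dev = cudarc::driver::CudaDevice::new(rank)?;
    let device = Device::new_cuda(rank)?;

    // NCCL communicator tying the ranks together.
    let comm = match Comm::from_rank(cuda_dev, rank, num_shards, id) {
        Ok(comm) => Rc::new(comm),
        Err(err) => anyhow::bail!("nccl error {:?}", err.0),
    };

    // Shard-aware access to the safetensors checkpoints.
    let vb = unsafe {
        candle_nn::var_builder::ShardedSafeTensors::var_builder(filenames, DType::BF16, &device)?
    };

    // A tensor-parallel model would take the communicator here so that each
    // rank only holds ~1/num_shards of the weights (hypothetical constructor):
    // let model = ShardedMistral::load(vb, &config, comm)?;
    let _ = (vb, comm);
    Ok(())
}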

evgenyigumnov commented 5 months ago

Unfortunately, it didn't help:

PS C:\Users\igumn\ai-server> cargo run
   Compiling ai-server v0.1.0 (C:\Users\igumn\ai-server)
error[E0616]: field `hidden_size` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:19:29
   |
19 |         let n_elem = config.hidden_size / config.num_attention_heads;
   |                             ^^^^^^^^^^^ private field

error[E0616]: field `num_attention_heads` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:19:50
   |
19 |         let n_elem = config.hidden_size / config.num_attention_heads;
   |                                                  ^^^^^^^^^^^^^^^^^^^ private field

error[E0616]: field `rope_theta` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:22:36
   |
22 |             .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
   |                                    ^^^^^^^^^^ private field

error[E0616]: field `num_hidden_layers` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:34:56
   |
34 |             kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
   |                                                        ^^^^^^^^^^^^^^^^^ private field

error[E0061]: this function takes 2 arguments but 4 arguments were supplied
   --> src\main.rs:379:25
    |
379 |             let model = Mistral::new(vb, &cache, &config, comm)?;
    |                         ^^^^^^^^^^^^ --  ------ unexpected argument of type `&model::Cache`
    |                                      |
    |                                      unexpected argument of type `VarBuilderArgs<'_, ShardedSafeTensors>`
    |
note: expected `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`, found `Rc<Comm>`
   --> src\main.rs:379:59
    |
379 |             let model = Mistral::new(vb, &cache, &config, comm)?;
    |                                                           ^^^^
    = note: expected struct `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`
               found struct `Rc<Comm>`
note: associated function defined here
   --> C:\Users\igumn\.cargo\registry\src\index.crates.io-6f17d22bba15001f\candle-transformers-0.4.1\src\models\mistral.rs:383:12
    |
383 |     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
    |            ^^^
help: remove the extra arguments
    |
379 -             let model = Mistral::new(vb, &cache, &config, comm)?;
379 +             let model = Mistral::new(, &config, /* VarBuilderArgs<'_, Box<dyn SimpleBackend>> */)?;
    |

warning: unused import: `candle_core::backend::BackendDevice`
  --> src\main.rs:18:5
   |
18 | use candle_core::backend::BackendDevice;
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |
   = note: `#[warn(unused_imports)]` on by default

warning: unused import: `candle_core::backend::BackendStorage`
 --> src\model.rs:2:5
  |
2 | use candle_core::backend::BackendStorage;
  |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some errors have detailed explanations: E0061, E0616.
For more information about an error, try `rustc --explain E0061`.
warning: `ai-server` (bin "ai-server") generated 2 warnings
error: could not compile `ai-server` (bin "ai-server") due to 5 previous errors; 2 warnings emitted
Cargo.toml

[package]
name = "ai-server"
version = "0.1.0"
edition = "2021"

[dependencies]

candle-nn = "0.4.1"
candle-core = "0.4.1"
candle-datasets = "0.4.1"
candle-transformers = "0.4.1"
candle-examples = "0.4.1"
hf-hub = "0.3.2"
tokenizers = "0.15.2"
anyhow = "1.0.81"
clap = { version = "4.5.3", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = "0.7.5"
serde = { version = "1.0.197", features = ["derive"] }
tokio = "1.36.0"
once_cell = "1.19.0"
futures = "0.3.30"
cudarc = "0.10.0"

[build-dependencies]
bindgen_cuda = { version = "0.1.1", optional = true }

[features]
default = ["cuda"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda", "cudarc/nccl"]
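
Note that the default cuda feature also enables cudarc/nccl, which is where the Comm and Id types imported in main.rs below come from.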

model.rs

use candle_core::backend::BackendStorage;
use candle_core::{ DType, Device, Result, Tensor};
use std::sync::{Arc, Mutex};
const MAX_SEQ_LEN: usize = 4096;
pub type Config = candle_transformers::models::mistral::Config;

#[derive(Clone)]
pub struct Cache {
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
}

impl Cache {
    pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        // precompute freqs_cis
        let n_elem = config.hidden_size / config.num_attention_heads;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        // This is different from the paper, see:
        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok(Self {
            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
            cos,
            sin,
        })
    }
}
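
One way around the E0616 errors above, as a sketch rather than the upstream fix: carry the four hyperparameters that Cache::new reads in a small local struct instead of reaching into the private Config fields. RopeParams is a made-up name; the numbers are the published Mistral-7B-v0.1 values from its config.json.

// Local copy of the hyperparameters needed by Cache::new.
pub struct RopeParams {
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub rope_theta: f32,
}

impl RopeParams {
    // Hyperparameters of mistralai/Mistral-7B-v0.1.
    pub fn mistral_7b_v0_1() -> Self {
        Self {
            hidden_size: 4096,
            num_attention_heads: 32,
            num_hidden_layers: 32,
            rope_theta: 10_000.0,
        }
    }
}

Cache::new would then take a &RopeParams instead of (or alongside) the &Config. This only unblocks model.rs; it does not make mistral::Model itself shard-aware.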

main.rs

use std::io::Write;
use std::rc::Rc;
use axum::{
    routing::post,
    routing::get,
    Json, Router,
};
use serde::{Deserialize, Serialize};
use axum::{response::Html};

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mistral::{Config, Model as Mistral};

use candle_core::{DType, Device, Tensor};
use candle_core::backend::BackendDevice;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use once_cell::sync::Lazy;
use futures::lock::Mutex;
use cudarc::nccl::safe::{Comm, Id};

static AI_SERVER: Lazy<Mutex<Option<TextGeneration>>> = Lazy::new(|| Mutex::new(None));

mod model;

#[tokio::main]
async fn main() -> anyhow::Result<()> {

    let args = Args::parse();
    let ai = init(&args)?;
    AI_SERVER.lock().await.replace(ai);

    let app = Router::new().route("/", post(handle_request)).route("/", get(handler));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:8181")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();

    Ok(())

}

#[derive(Deserialize)]
struct Input {
    doc1: String,
    doc2: String,
}

#[derive(Serialize)]
struct Response {
    result: String,
}

async fn handle_request(Json(payload): Json<Input>) -> Json<Response> {

    let prompt = format!(r#"
{}

Which category does the text above belong to? Here is the list of categories:

{}

Give me a short answer. Just a category number. Category number:"#, payload.doc1, payload.doc2);

    let mut ai_server_mut = AI_SERVER.lock().await;
    let ai_server_opt = ai_server_mut.as_mut();
    match ai_server_opt {
        Some(ai_server) => {
            let result_opt = ai_server.run(&prompt, 3);
            match result_opt {
                Ok(result) => {
                    Json(Response { result: result.replace("%","").replace("\n","") })
                }
                Err(e) => {
                    Json(Response { result: e.to_string() })
                }
            }
        }
        None => {
            Json(Response { result: "AI server not initialized".to_string() })
        }
    }
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Server status: online</h1>")
}

struct TextGeneration {
    model: Mistral,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Mistral,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
        let mut result_string = String::new();
        self.tokenizer.clear();
        self.model.clear_kv_cache();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        println!("{}", prompt);
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
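        // The first iteration feeds the whole prompt; every later iteration
        // feeds only the token just sampled, with start_pos telling the model's
        // KV cache where that token sits in the sequence.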
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                result_string.push_str(&t);
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            result_string.push_str(&rest);
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );

        Ok(result_string)
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long)]
    num_shards: usize,

    #[arg(long)]
    rank: Option<usize>,

    #[arg(long)]
    start_port: Option<usize>,

    #[arg(long)]
    cpu: bool,
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    temperature: Option<f64>,

    #[arg(long)]
    top_p: Option<f64>,

    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn init(args: &Args) -> Result<TextGeneration> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    // let args = Args {
    //     cpu: false,
    //     tracing: false,
    //     use_flash_attn: false,
    //     temperature: None,
    //     top_p: None,
    //     seed: 299792458,
    //     model_id: None,
    //     revision: "main".to_string(),
    //     tokenizer_file: None,
    //     weight_files: None,
    //     repeat_penalty: 1.1,
    //     repeat_last_n: 3,
    // };

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            "mistralai/Mistral-7B-v0.1".to_string()
        }
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision.to_string(),
    ));
    let tokenizer_filename = match &args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match &args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());

    if args.rank.is_none() && args.start_port.is_some() {
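        // Launcher mode: --start-port was given without --rank, so re-exec this
        // binary once per shard with "--rank <n>" appended (e.g.
        // cargo run -- --num-shards 4 --start-port 3000 spawns ranks 0..=3),
        // wait for the children, then exit. Each child then takes the
        // args.start_port.is_some() branch further down with its own rank.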
        let children: Vec<_> = (0..args.num_shards)
            .map(|rank| {
                let mut args: std::collections::VecDeque<_> = std::env::args().collect();
                args.push_back("--rank".to_string());
                args.push_back(format!("{rank}"));
                let name = args.pop_front().unwrap();
                std::process::Command::new(name).args(args).spawn().unwrap()
            })
            .collect();
        for mut child in children {
            child.wait().unwrap();
        }
        return Err(E::msg("all children have exited"));
    }

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);

    if args.start_port.is_some() {

        let comm_file: String = "nccl_id.txt".to_string();

        let comm_file = std::path::PathBuf::from(comm_file);
        if comm_file.exists() {
            return Err(E::msg("comm file already exists, please remove it first"));
        }

        let (model, device) = {
            let num_shards = args.num_shards;
            // Primitive IPC
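            // Rank 0 creates the NCCL unique Id and publishes it through
            // nccl_id.txt (written to a temp file first, then renamed so the
            // other ranks never see a partial write); every other rank polls
            // until the file exists and rebuilds the Id from its raw bytes.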
            let id = if args.rank.unwrap() == 0 {
                let id = Id::new().unwrap();
                let tmp_file = comm_file.with_extension(".comm.tgz");
                std::fs::File::create(&tmp_file)?
                    .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
                std::fs::rename(&tmp_file, &comm_file)?;
                id
            } else {
                while !comm_file.exists() {
                    std::thread::sleep(std::time::Duration::from_secs(1));
                }
                let data = std::fs::read(&comm_file)?;
                let internal: [i8; 128] = data
                    .into_iter()
                    .map(|i| i as i8)
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                let id: Id = Id::uninit(internal);
                id
            };

            let device_arc = cudarc::driver::CudaDevice::new(args.rank.unwrap())?;
            let device = Device::new_cuda(args.rank.unwrap())?;
            let dtype = if device.is_cuda() {
                DType::BF16
            } else {
                DType::F32
            };

            let cache = model::Cache::new(dtype, &config, &device)?;

            let vb = unsafe {
                candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
            };
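
            // With a tensor-parallel model implementation (as in the
            // llama_multiprocess example) each rank would combine vb with the
            // NCCL communicator below to load only its slice of every weight
            // onto its own GPU; the stock mistral::Model is not shard-aware,
            // which is what the E0061 error above is pointing at.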

            let comm = match Comm::from_rank(device_arc, args.rank.unwrap(), args.num_shards, id) {
                Ok(comm) => Rc::new(comm),
                Err(err) => anyhow::bail!("nccl error {:?}", err.0),
            };
            let model = Mistral::new(vb, &cache, &config, comm)?;

            (model, device)
        };

        println!("loaded the model in {:?}", start.elapsed());
        let pipeline = TextGeneration::new(
            model,
            tokenizer,
            299792458,
            Some(0.0),
            Some(0.0),
            1.1,
            64,
            &device,
        );

        Ok(pipeline)

    } else {
        let device = candle_examples::device(false)?;

        let (model, device) = {
            let dtype = if device.is_cuda() {
                DType::BF16
            } else {
                DType::F32
            };
            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
            let model = Mistral::new(&config, vb)?;
            (model, device)
        };

        println!("loaded the model in {:?}", start.elapsed());
        let pipeline = TextGeneration::new(
            model,
            tokenizer,
            299792458,
            Some(0.0),
            Some(0.0),
            1.1,
            64,
            &device,
        );

        Ok(pipeline)

    }

}
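
For reference, the flow this code is set up for: launch with something like cargo run --release -- --num-shards 4 --start-port 3000 (the launcher branch in init then re-execs one child per rank), and query it by POSTing a JSON body such as {"doc1": "<text to classify>", "doc2": "<numbered list of categories>"} to http://127.0.0.1:8181/; the reply is {"result": "<category number>"}.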

As I understand it, these errors would go away if changes were made to the following files:

  1. candle_transformers::models::mistral::Config
  2. candle-transformers-0.4.1\src\models\mistral.rs

Is that right?
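
From the compiler output alone, only partly: making those Config fields public (or adding getters) would clear the E0616 errors, but the E0061 error points at something deeper. mistral::Model::new in candle-transformers 0.4.1 is declared as pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self>, so it has nowhere to accept a Cache or an NCCL Comm; running Mistral across the four GPUs would seemingly need a tensor-parallel model implementation along the lines of the llama_multiprocess example, not just changes at the call sites.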