Open evgenyigumnov opened 5 months ago
Definitely possible but won't just work with standard examples. Please take a look at this example for how to run across multiple GPUs: https://github.com/huggingface/candle/tree/main/candle-examples/examples/llama_multiprocess
Unfortunately it didn't help
PS C:\Users\igumn\ai-server> cargo run
Compiling ai-server v0.1.0 (C:\Users\igumn\ai-server)
error[E0616]: field `hidden_size` of struct `candle_transformers::models::mistral::Config` is private
--> src\model.rs:19:29
|
19 | let n_elem = config.hidden_size / config.num_attention_heads;
| ^^^^^^^^^^^ private field
error[E0616]: field `num_attention_heads` of struct `candle_transformers::models::mistral::Config` is private
--> src\model.rs:19:50
|
19 | let n_elem = config.hidden_size / config.num_attention_heads;
| ^^^^^^^^^^^^^^^^^^^ private field
error[E0616]: field `rope_theta` of struct `candle_transformers::models::mistral::Config` is private
--> src\model.rs:22:36
|
22 | .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
| ^^^^^^^^^^ private field
error[E0616]: field `num_hidden_layers` of struct `candle_transformers::models::mistral::Config` is private
--> src\model.rs:34:56
|
34 | kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
| ^^^^^^^^^^^^^^^^^ private field
error[E0061]: this function takes 2 arguments but 4 arguments were supplied
--> src\main.rs:379:25
|
379 | let model = Mistral::new(vb, &cache, &config, comm)?;
| ^^^^^^^^^^^^ -- ------ unexpected argument of type `&model::Cache`
| |
| unexpected argument of type `VarBuilderArgs<'_, ShardedSafeTensors>`
|
note: expected `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`, found `Rc<Comm>`
--> src\main.rs:379:59
|
379 | let model = Mistral::new(vb, &cache, &config, comm)?;
| ^^^^
= note: expected struct `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`
found struct `Rc<Comm>`
note: associated function defined here
--> C:\Users\igumn\.cargo\registry\src\index.crates.io-6f17d22bba15001f\candle-transformers-0.4.1\src\models\mistral.rs:383:12
|
383 | pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
| ^^^
help: remove the extra arguments
|
379 - let model = Mistral::new(vb, &cache, &config, comm)?;
379 + let model = Mistral::new(, &config, /* VarBuilderArgs<'_, Box<dyn SimpleBackend>> */)?;
|
warning: unused import: `candle_core::backend::BackendDevice`
--> src\main.rs:18:5
|
18 | use candle_core::backend::BackendDevice;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
warning: unused import: `candle_core::backend::BackendStorage`
--> src\model.rs:2:5
|
2 | use candle_core::backend::BackendStorage;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some errors have detailed explanations: E0061, E0616.
For more information about an error, try `rustc --explain E0061`.
warning: `ai-server` (bin "ai-server") generated 2 warnings
error: could not compile `ai-server` (bin "ai-server") due to 5 previous errors; 2 warnings emitted
[package]
name = "ai-server"
version = "0.1.0"
edition = "2021"
[dependencies]
candle-nn = "0.4.1"
candle-core = "0.4.1"
candle-datasets = "0.4.1"
candle-transformers = "0.4.1"
candle-examples = "0.4.1"
hf-hub = "0.3.2"
tokenizers = "0.15.2"
anyhow = "1.0.81"
clap = { version = "4.5.3", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = "0.7.5"
serde = { version = "1.0.197", features = ["derive"] }
tokio = "1.36.0"
once_cell = "1.19.0"
futures = "0.3.30"
cudarc = "0.10.0"
[build-dependencies]
bindgen_cuda = { version = "0.1.1", optional = true }
[features]
default = ["cuda"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda", "cudarc/nccl"]
model.rs
use candle_core::backend::BackendStorage;
use candle_core::{ DType, Device, Result, Tensor};
use std::sync::{Arc, Mutex};
/// Number of positions for which `Cache::new` precomputes the cos/sin tables.
const MAX_SEQ_LEN: usize = 4096;
/// Local alias so this module can name the candle Mistral configuration simply `Config`.
pub type Config = candle_transformers::models::mistral::Config;
/// State produced by `Cache::new`: one key/value slot per hidden layer plus the
/// precomputed cosine and sine tables.
#[derive(Clone)]
pub struct Cache {
// One optional tensor pair per hidden layer, all initialised to `None` by `Cache::new`.
// Presumably the (key, value) cache of each layer; nothing in this file writes to it.
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
// Cosine of the position-times-frequency table, converted to the requested dtype.
cos: Tensor,
// Sine of the same table, converted to the requested dtype.
sin: Tensor,
}
impl Cache {
    /// Precomputes the cos/sin lookup tables for `MAX_SEQ_LEN` positions and
    /// allocates one empty key/value slot per hidden layer.
    ///
    /// * `dtype`  - element type the cos/sin tables are converted to.
    /// * `config` - supplies `hidden_size`, `num_attention_heads`, `rope_theta`
    ///   and `num_hidden_layers`.
    /// * `device` - device on which every tensor is created.
    ///
    /// # Errors
    /// Returns the candle error of the first tensor operation that fails.
    pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        // Per-head width; the frequencies below are taken at every second index of it.
        let head_dim = config.hidden_size / config.num_attention_heads;
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;

        // Column vector of positions 0..MAX_SEQ_LEN, multiplied by the row of
        // frequencies to obtain one angle per (position, frequency) pair.
        let positions = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?;
        let freq_row = inv_freq.reshape((1, inv_freq.elem_count()))?;
        let angles = positions.matmul(&freq_row)?;

        // This is different from the paper, see:
        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
        let cos = angles.cos()?.to_dtype(dtype)?;
        let sin = angles.sin()?.to_dtype(dtype)?;

        let kvs = Arc::new(Mutex::new(vec![None; config.num_hidden_layers]));
        Ok(Self { kvs, cos, sin })
    }
}
main.rs
use std::io::Write;
use std::rc::Rc;
use axum::{
routing::post,
routing::get,
Json, Router,
};
use serde::{Deserialize, Serialize};
use axum::{response::Html};
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_core::{DType, Device, Tensor};
use candle_core::backend::BackendDevice;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use once_cell::sync::Lazy;
use futures::lock::Mutex;
use cudarc::nccl::safe::{Comm, Id};
// Process-wide slot for the text-generation pipeline. It starts out as `None`, is filled
// by `main` once `init` has succeeded, and is locked by `handle_request` for every request.
static AI_SERVER: Lazy<Mutex<Option<TextGeneration>>> = Lazy::new(|| Mutex::new(None));
// `model.rs`: the `Cache` helper used by the multi-GPU path of `init`.
mod model;
/// Parses the command line, loads the model through `init`, stores the resulting
/// pipeline in `AI_SERVER` and then serves HTTP on `127.0.0.1:8181`
/// (`POST /` -> `handle_request`, `GET /` -> `handler`).
///
/// # Errors
/// Returns the error from `init`, from binding the listener (e.g. the port is
/// already in use) or from the server loop instead of panicking.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let ai = init(&args)?;
    AI_SERVER.lock().await.replace(ai);

    let app = Router::new()
        .route("/", post(handle_request))
        .route("/", get(handler));

    // I/O failures are propagated with `?`: `main` already returns a `Result`,
    // so a busy port yields a normal error message rather than a panic.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8181").await?;
    println!("listening on {}", listener.local_addr()?);
    axum::serve(listener, app).await?;
    Ok(())
}
/// JSON body accepted by `handle_request`.
#[derive(Deserialize)]
struct Input {
// Text to classify; inserted first in the prompt ("the text above").
doc1: String,
// List of categories; inserted after the question in the prompt.
doc2: String,
}
/// JSON body returned by `handle_request`.
#[derive(Serialize)]
struct Response {
// The generated answer with `%` and newlines removed, or an error message when
// generation failed or the pipeline has not been initialised.
result: String,
}
/// `POST /` handler: asks the model which category `doc1` belongs to, given the
/// category list in `doc2`, and returns at most three generated tokens.
///
/// Failures are reported in the `result` field of the response rather than as an
/// HTTP error: either the generation error text or "AI server not initialized".
///
/// NOTE: `TextGeneration::run` is synchronous, so generation runs inside this
/// handler while the `AI_SERVER` lock is held.
async fn handle_request(Json(payload): Json<Input>) -> Json<Response> {
    let prompt = format!(r#"
{}
Which category does the text above belong to? Here is the list of categories:
{}
Give me a short answer. Just a category number. Category number:"#, payload.doc1, payload.doc2);

    let mut slot = AI_SERVER.lock().await;
    let Some(generator) = slot.as_mut() else {
        return Json(Response { result: "AI server not initialized".to_string() });
    };

    // Strip the characters the caller does not want in the answer; on failure the
    // error text is returned verbatim.
    let result = match generator.run(&prompt, 3) {
        Ok(answer) => answer.replace("%","").replace("\n",""),
        Err(err) => err.to_string(),
    };
    Json(Response { result })
}
/// `GET /` handler: returns a fixed HTML snippet saying the server is online.
async fn handler() -> Html<&'static str> {
    let status_page: &'static str = "<h1>Server status: online</h1>";
    Html(status_page)
}
/// A loaded model together with everything `run` needs to generate text from a prompt.
struct TextGeneration {
// The Mistral model; `run` clears its KV cache and calls `forward` on it.
model: Mistral,
// Device on which `run` creates the input tensors.
device: Device,
// Tokenizer wrapped in a streaming decoder that yields text pieces token by token.
tokenizer: TokenOutputStream,
// Sampler that picks the next token from the logits.
logits_processor: LogitsProcessor,
// Penalty factor passed to `apply_repeat_penalty`; a value of exactly 1 skips the penalty.
repeat_penalty: f32,
// Number of most recent tokens the repeat penalty is applied to.
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Mistral,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
let mut result_string = String::new();
self.tokenizer.clear();
self.model.clear_kv_cache();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("{}", prompt);
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
result_string.push_str(&t);
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
result_string.push_str(&rest);
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(result_string)
}
}
// Command-line options, parsed by clap.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments
// into help text, which would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Number of shards: how many child processes `init` spawns and the world size
// handed to NCCL. Required (no default).
#[arg(long)]
num_shards: usize,
// Index of this process among the shards; also used as the CUDA device ordinal.
// `init` appends it automatically when it re-executes the program for each shard.
#[arg(long)]
rank: Option<usize>,
// Only its presence is checked: when set, `init` takes the multi-process (NCCL) path.
#[arg(long)]
start_port: Option<usize>,
// Request CPU execution (intended meaning, going by the name; see `init` for how
// the device is chosen).
#[arg(long)]
cpu: bool,
// Install the chrome tracing layer in `init`.
#[arg(long)]
tracing: bool,
// Forwarded to `Config::config_7b_v0_1`.
#[arg(long)]
use_flash_attn: bool,
// Sampling temperature (optional).
#[arg(long)]
temperature: Option<f64>,
// Nucleus-sampling probability cutoff (optional).
#[arg(long)]
top_p: Option<f64>,
// Seed for the sampler; defaults to 299792458.
#[arg(long, default_value_t = 299792458)]
seed: u64,
// Hugging Face model repository; `init` falls back to "mistralai/Mistral-7B-v0.1".
#[arg(long)]
model_id: Option<String>,
// Repository revision to fetch; defaults to "main".
#[arg(long, default_value = "main")]
revision: String,
// Local tokenizer file; when absent, `tokenizer.json` is fetched from the repository.
#[arg(long)]
tokenizer_file: Option<String>,
// Comma-separated list of local weight files; when absent, the files are resolved
// through `model.safetensors.index.json` in the repository.
#[arg(long)]
weight_files: Option<String>,
// Repeat-penalty factor; defaults to 1.1. `TextGeneration::run` skips the penalty
// when the factor is exactly 1.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
// Number of trailing tokens the repeat penalty considers; defaults to 64.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
/// Fetches the tokenizer and weights, loads the Mistral model and assembles the
/// `TextGeneration` pipeline.
///
/// The path taken depends on `--start-port` and `--rank`:
/// * `--start-port` set, `--rank` unset: re-executes the current command line once per
///   shard with `--rank <n>` appended, waits for every child, then returns the error
///   "all children have exited" (this parent process never builds a pipeline).
/// * `--start-port` and `--rank` set: shares the NCCL id through the file `nccl_id.txt`
///   (rank 0 writes it, the other ranks poll for it) and builds the model on CUDA
///   device `rank`.
/// * `--start-port` unset: loads the whole model on a single device.
///
/// The sampling options (`--seed`, `--temperature`, `--top-p`, `--repeat-penalty`,
/// `--repeat-last-n`) and `--cpu` are taken from `args`. The seed, repeat-penalty and
/// repeat-last-n defaults equal the values that used to be hard-coded here.
/// Temperature and top-p now default to `None` instead of `Some(0.0)`; assuming
/// `LogitsProcessor` treats a zero temperature like an absent one (greedy decoding),
/// the default behaviour is unchanged — confirm against candle-transformers.
///
/// # Errors
/// Returns an error when the hub API, file retrieval, tokenizer loading, process
/// spawning, comm-file I/O, device creation, NCCL setup or model loading fails, when
/// `nccl_id.txt` already exists, or when it does not hold exactly 128 bytes.
///
/// # Panics
/// Panics if `Id::new()` fails.
fn init(args: &Args) -> Result<TextGeneration> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    // NOTE(review): `_guard` is a local and is dropped when this function returns;
    // if the chrome layer stops recording when its guard is dropped, tracing covers
    // model loading only. Confirm whether the guard should outlive `init`.
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = args
        .model_id
        .clone()
        .unwrap_or_else(|| "mistralai/Mistral-7B-v0.1".to_string());
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision.clone(),
    ));
    let tokenizer_filename = match &args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match &args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());

    // Parent of a multi-process run: start one child per shard and wait for all of them.
    if args.rank.is_none() && args.start_port.is_some() {
        let children = (0..args.num_shards)
            .map(|rank| {
                let mut child_args: std::collections::VecDeque<_> = std::env::args().collect();
                child_args.push_back("--rank".to_string());
                child_args.push_back(rank.to_string());
                let name = child_args
                    .pop_front()
                    .expect("the deque holds at least the two arguments pushed above");
                std::process::Command::new(name).args(child_args).spawn()
            })
            .collect::<std::io::Result<Vec<_>>>()?;
        for mut child in children {
            child.wait()?;
        }
        return Err(E::msg("all children have exited"));
    }

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);

    let (model, device) = if args.start_port.is_some() {
        let rank = args
            .rank
            .expect("the spawning branch above returns whenever `rank` is unset");

        // NOTE(review): every rank performs this check, so a rank that reaches it after
        // rank 0 has already written the file fails here; the file is also never
        // removed after use. Confirm the intended life cycle of `nccl_id.txt`.
        let comm_file = std::path::PathBuf::from("nccl_id.txt");
        if comm_file.exists() {
            return Err(E::msg("comm file already exists, please remove it first"));
        }

        // Primitive IPC: rank 0 creates the NCCL id and publishes it through a file
        // (written under a temporary name, then renamed); the other ranks poll once a
        // second until the file appears.
        let id = if rank == 0 {
            let id = Id::new().unwrap();
            let tmp_file = comm_file.with_extension(".comm.tgz");
            std::fs::File::create(&tmp_file)?
                .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
            std::fs::rename(&tmp_file, &comm_file)?;
            id
        } else {
            while !comm_file.exists() {
                std::thread::sleep(std::time::Duration::from_secs(1));
            }
            let data = std::fs::read(&comm_file)?;
            let internal: [i8; 128] = data
                .into_iter()
                .map(|i| i as i8)
                .collect::<Vec<_>>()
                .try_into()
                .map_err(|bytes: Vec<i8>| {
                    E::msg(format!(
                        "comm file holds {} bytes, expected 128",
                        bytes.len()
                    ))
                })?;
            Id::uninit(internal)
        };

        let device_arc = cudarc::driver::CudaDevice::new(rank)?;
        let device = Device::new_cuda(rank)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let cache = model::Cache::new(dtype, &config, &device)?;
        // SAFETY: presumably unsafe because the weight files are memory-mapped and must
        // not be modified while in use — confirm against the candle-nn documentation.
        let vb = unsafe {
            candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
        };
        let comm = match Comm::from_rank(device_arc, rank, args.num_shards, id) {
            Ok(comm) => Rc::new(comm),
            Err(err) => anyhow::bail!("nccl error {:?}", err.0),
        };
        // NOTE(review): `candle_transformers::models::mistral::Model::new` is declared as
        // `new(cfg: &Config, vb: VarBuilder)`; this four-argument call needs a sharded
        // Mistral implementation taking `(vb, cache, config, comm)` and does not compile
        // against the upstream model.
        let model = Mistral::new(vb, &cache, &config, comm)?;
        (model, device)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        // SAFETY: presumably unsafe because the weight files are memory-mapped and must
        // not be modified while in use — confirm against the candle-nn documentation.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Mistral::new(&config, vb)?;
        (model, device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    Ok(TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    ))
}
As I understand it, these errors would go away if changes were made to the following files:
Is that right?
I have 4x RTX 3080 GPUs: 40 GB of memory in total (10 GB per GPU).
I am trying to load the Mistral 7B model, whose weights file is about 15 GB.
But I get an error:
Is it possible to run in multi-GPU mode?