EricLBuehler / mistral.rs

Blazingly fast LLM inference.

[Not really a FR] Not recreating channel inside the loop #784

Closed · rickbeeloo closed this 1 month ago

rickbeeloo commented 1 month ago

I was looking at the batch example and noticed that the channel gets recreated for every request:

 for _ in 0..n_requests {
        let (tx, rx) = channel(10_000);
        ....

In tokio, the channel is normally created once and requests are pushed to it, so this confused me at first. Maybe an example like this would be more intuitive:

use std::sync::Arc;

use either::Either;
use indexmap::IndexMap;
use tokio::sync::mpsc::{channel, Receiver, Sender};

use mistralrs::{
    initialize_logging, ChatCompletionResponse, Constraint, Device, DeviceMapMetadata,
    GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
    Request, RequestMessage, Response, ResponseOk, SamplingParams, SchedulerConfig, TokenSource,
    Usage,
};

async fn setup() -> anyhow::Result<Arc<MistralRs>> {
    // Select a model (Llama 3.1 8B Instruct here).
    // The model, tokenizer, and chat template are pulled from the HF hub.

    let loader = NormalLoaderBuilder::new(
        NormalSpecificConfig {
            use_flash_attn: true,
            prompt_batchsize: None,
            topology: None,
            organization: Default::default(),
            write_uqff: None,
            from_uqff: None,
        },
        None, 
        None,
        Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
    )
    .build(None)?;

    // Load it into a Pipeline.
    let pipeline = loader.load_model_from_hf(
        None,
        TokenSource::CacheToken,
        &ModelDType::default(),
        &Device::cuda_if_available(0)?,
        false,
        DeviceMapMetadata::dummy(),
        None,
        Some(PagedAttentionConfig::new(
            Some(32),
            500,
            MemoryGpuConfig::Utilization(0.9),
        )?), // Enable PagedAttention.
    )?;
    let config = pipeline
        .lock()
        .await
        .get_metadata()
        .cache_config
        .as_ref()
        .unwrap()
        .clone();
    // Create the MistralRs, which is a runner
    Ok(MistralRsBuilder::new(
        pipeline,
        SchedulerConfig::PagedAttentionMeta {
            max_num_seqs: 500,
            config,
        },
    )
    .with_throughput_logging()
    .build())
}

async fn bench_mistralrs(n_requests: usize) -> anyhow::Result<()> {
    initialize_logging();
    let mistralrs = setup().await?;

    // Create a single channel
    let (tx, rx) = channel::<Response>(n_requests);

    // Send all requests
    for i in 0..n_requests {
        let request = create_request(i, tx.clone());
        mistralrs.get_sender()?.send(request).await?;
    }

    // Collect responses
    let responses = collect_responses(rx, n_requests).await;

    // Process results
    process_results(responses);

    Ok(())
}

fn create_request(id: usize, tx: Sender<Response>) -> Request {
    Request::Normal(NormalRequest {
        messages: RequestMessage::Chat(vec![IndexMap::from([
            ("role".to_string(), Either::Left("user".to_string())),
            (
                "content".to_string(),
                Either::Left("What is 1+1?".to_string()),
            ),
        ])]),
        sampling_params: SamplingParams::default(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        id,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
    })
}

async fn collect_responses(mut rx: Receiver<Response>, n_requests: usize) -> Vec<Response> {
    let mut responses = Vec::with_capacity(n_requests);
    for _ in 0..n_requests {
        if let Some(response) = rx.recv().await {
            responses.push(response);
        }
    }
    responses
}

fn process_results(responses: Vec<Response>) {
    let mut max_prompt = f32::MIN;
    let mut max_completion = f32::MIN;

    for response in responses {
        match response.as_result() {
            Ok(ResponseOk::Done(c)) => {
                max_prompt = max_prompt.max(c.usage.avg_prompt_tok_per_sec);
                max_completion = max_completion.max(c.usage.avg_compl_tok_per_sec);
            }
            Ok(_) => unreachable!(),
            Err(_) => unreachable!(),
        }
    }
    println!(
        "Max stats: {} prompt tokens/s, {} completion tokens/s",
        max_prompt, max_completion
    );
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    bench_mistralrs(10).await?;
    Ok(())
}

I think that this part:

async fn bench_mistralrs(n_requests: usize) -> anyhow::Result<()> {
    initialize_logging();
    let mistralrs = setup().await?;

    // Create a single channel
    let (tx, rx) = channel::<Response>(n_requests);

    // Send all requests
    for i in 0..n_requests {
        let request = create_request(i, tx.clone());
        mistralrs.get_sender()?.send(request).await?;
    }

    // Collect responses
    let responses = collect_responses(rx, n_requests).await;

    // Process results
    process_results(responses);

    Ok(())
}

now matches the tokio mpsc example more closely:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(100);

    tokio::spawn(async move {
        for i in 0..10 {
            if let Err(_) = tx.send(i).await {
                println!("receiver dropped");
                return;
            }
        }
    });

    while let Some(i) = rx.recv().await {
        println!("got = {}", i);
    }
}
rickbeeloo commented 1 month ago

I wonder how this works behind the scenes in mistral.rs. If you recreate the channel on each iteration like in the example, does that load all requests into memory, or is it still capped at 10,000?
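
For context (not specific to mistral.rs internals): a bounded tokio mpsc channel only bounds the messages buffered in that one channel, and send().await suspends once its buffer is full. With a fresh channel(10_000) per request, each channel holds at most one response, so the 10,000 cap never bites and every pending response stays in memory; a single shared channel is what actually bounds buffered responses. A minimal sketch of the blocking behavior, using plain tokio:

use tokio::sync::mpsc;
use tokio::time::{sleep, Duration};

#[tokio::main]
async fn main() {
    // A bounded channel with room for only 2 unread messages.
    let (tx, mut rx) = mpsc::channel::<usize>(2);

    tokio::spawn(async move {
        for i in 0..5 {
            // Suspends here whenever 2 messages are already buffered,
            // i.e. the producer is throttled by the slow consumer below.
            tx.send(i).await.unwrap();
            println!("sent {}", i);
        }
    });

    while let Some(i) = rx.recv().await {
        // Slow consumer: the producer can only ever be 2 messages ahead.
        sleep(Duration::from_millis(100)).await;
        println!("got {}", i);
    }
}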

rickbeeloo commented 1 month ago

Let me just close this, as it's not really a necessary change, but people can still find it here.

EricLBuehler commented 1 month ago

@rickbeeloo I think you are correct. Perhaps using an mpsc channel is overkill.

All we need, internally, is a way to send the request and return the response across threads. Internally, the Engine runs on a different thread. If you have any ideas about how to improve this, I would accept a PR!
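
One possible direction, purely as a sketch and not how mistral.rs is wired today: since a non-streaming request expects exactly one response, the request could carry a tokio::sync::oneshot::Sender and the engine thread would reply directly to the caller. The EngineRequest and EngineResponse types below are hypothetical placeholders, not mistral.rs types:

use tokio::sync::{mpsc, oneshot};

// Hypothetical stand-ins for the real request/response types.
struct EngineRequest {
    prompt: String,
    // The engine sends exactly one reply back over this.
    respond_to: oneshot::Sender<EngineResponse>,
}

#[derive(Debug)]
struct EngineResponse {
    text: String,
}

#[tokio::main]
async fn main() {
    // One mpsc channel feeds requests to the engine task (stand-in for the Engine thread)...
    let (engine_tx, mut engine_rx) = mpsc::channel::<EngineRequest>(100);

    tokio::spawn(async move {
        while let Some(req) = engine_rx.recv().await {
            // ...and each request is answered over its own oneshot channel.
            let _ = req.respond_to.send(EngineResponse {
                text: format!("echo: {}", req.prompt),
            });
        }
    });

    let (tx, rx) = oneshot::channel();
    engine_tx
        .send(EngineRequest {
            prompt: "What is 1+1?".to_string(),
            respond_to: tx,
        })
        .await
        .unwrap();

    println!("{:?}", rx.await.unwrap());
}

Streaming responses would still need an mpsc (or similar) channel per request, so this only covers the single-response case.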