guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0
2.6k stars 215 forks source link

Does the Marian model have a method like Hugging Face's `generate`? #414

Open wolf-li opened 1 year ago

wolf-li commented 1 year ago

Using the pipeline is slower than calling `generate` from the Python Hugging Face `transformers` library once the model file is loaded, in a CPU environment.

guillaume-be commented 1 year ago

The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.

wolf-li commented 1 year ago

The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.

Thanks for your reply. Even when the Marian pipeline runs on the GPU (I specify `Device::Cuda(3)`, and `nvidia-smi` shows the GPU in use while the Rust program runs), it is slower than the Python Hugging Face library running in a Docker environment without a GPU.

Python code (CPU: 0.3 s, averaged over 100 requests)

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MarianTokenizer, MarianMTModel
# Load the Marian translation model and its tokenizer once, at import time,
# so every request reuses them. Swap model_name for another Helsinki-NLP pair
# to change the language direction.
model_name = "Helsinki-NLP/opus-mt-zh-en"
model, tokenizer = (
    MarianMTModel.from_pretrained(model_name),
    MarianTokenizer.from_pretrained(model_name),
)

app = FastAPI()

class InputData(BaseModel):
    """Request body for ``/v1/predict``: the text to translate."""

    text: str

class OutputData(BaseModel):
    """Response body for ``/v1/predict``."""

    generation_text: str


@app.post("/v1/predict", response_model=OutputData)
def predict(input_data: InputData):
    """Translate ``input_data.text`` and return the first decoded sequence.

    ``response_model`` must be a type/model; the previous
    ``dict(generation_text=str)`` was a dict *instance*, not a schema.

    Declared as a plain ``def`` rather than ``async def``: tokenize/generate/
    decode are all blocking calls, and FastAPI runs sync endpoints in its
    thread pool instead of on the event loop.
    """
    input_ids = tokenizer.encode(input_data.text, return_tensors="pt")
    translation_ids = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generation_text = tokenizer.decode(translation_ids[0], skip_special_tokens=True)
    return {"generation_text": generation_text}

if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run

    run(app, host="0.0.0.0", port=8000)

Rust code (CPU: 0.71 s, GPU: 0.38 s, averaged over 100 requests)

extern crate anyhow;

use std::sync::{Arc, Mutex, RwLock};

use actix_web::{
    error, get,
    http::{header::ContentType, StatusCode},
    post, web, App, HttpRequest, HttpResponse, HttpServer, Responder, Result,
};
use anyhow::Error;
use derive_more::{Display, Error};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{
    BufferResource, LocalResource, RemoteResource, Resource, ResourceProvider,
};
use serde::Deserialize;
use serde::Serialize;
use tch::Device;

/// JSON request body for `POST /v1/predict`.
#[derive(Deserialize)]
struct Input {
    /// Text to translate.
    text: String,
}

/// JSON response body for `POST /v1/predict`.
#[derive(Serialize)]
struct Output {
    /// The first translation returned by the pipeline.
    generation_text: String,
}

struct ModelFile {
    config_resource:LocalResource,
    weights: Arc<RwLock<Vec<u8>>>,
    vocab_resource: LocalResource,
    merges_resource: LocalResource,
}

impl ModelFile {
    fn new(model_path:String, config_path:String, vocab_path:String, merges_path:String) -> Self  {
        let weights =  Arc::new(RwLock::new(get_weights(model_path.clone()).unwrap()));

        let config_resource = LocalResource { local_path: config_path.into(), };
        let vocab_resource = LocalResource { local_path: vocab_path.into(), };
        let merges_resource = LocalResource { local_path: merges_path.into(), };
        Self {
            weights,
            config_resource,
            vocab_resource,
            merges_resource,
        }
    }

    fn genertation(&self, input_context:&str) -> Result<impl Responder, MyError> {

        let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
        let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;

        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            // ModelResource::Torch(Box::new(BufferResource { data: self.weights })),
            ModelResource::Torch(Box::new(BufferResource { data: Arc::clone(&self.weights) })),
            self.config_resource.clone(),
            self.vocab_resource.clone(),
            Some(self.merges_resource.clone()),
            source_languages,
            target_languages,
            // Device::Cpu,
            Device::Cuda(3),
        );

        let model = TranslationModel::new(translation_config).map_err(|e| {
            MyError::ModelLoadError
        })?;

        // let output = model.translate(&[input_context.to_string()], None, None);
        let output = model.translate(&[input_context.to_string()], None, None).map_err(|e| {
            MyError::TranslateError
        });

        match output {
            Ok(vec) => {
                if let Some(first_element) = vec.get(0) {
                    Ok(web::Json(Output { generation_text: first_element.to_string(),
                    }))
                } 
                else{
                    Err(MyError::TranslateError)
                }
            }
            Err(error) => {
                // Handle the error case
                Err(MyError::TranslateError)
            }
        }

    }

}

/// Failures of the translation endpoint; each becomes an HTTP error response
/// through its `ResponseError` implementation.
#[derive(Debug, Display, Error)]
enum MyError {
    /// `TranslationModel::new` failed.
    #[display(fmt = "translationModel load error")]
    ModelLoadError,

    /// Translation failed or produced no output.
    #[display(fmt = "translate error")]
    TranslateError,
}

impl error::ResponseError for MyError {
    /// Builds a response whose body is the error's `Display` text, sent with an
    /// HTML content type and the status code from [`Self::status_code`].
    fn error_response(&self) -> HttpResponse {
        let status = self.status_code();
        let message = self.to_string();

        let mut response = HttpResponse::build(status);
        response.insert_header(ContentType::html());
        response.body(message)
    }

    /// Every variant is a server-side failure, so all map to 500.
    fn status_code(&self) -> StatusCode {
        match self {
            MyError::ModelLoadError | MyError::TranslateError => {
                StatusCode::INTERNAL_SERVER_ERROR
            }
        }
    }
}

#[post("/v1/predict")]
async fn predicet_post(
    info: web::Json<Input>,
    appdata: web::Data<ModelFile>,
) -> Result<impl Responder, MyError> {
    let result = appdata.genertation(&info.text);
    result
}

/// Builds the shared `ModelFile` once at startup, then serves `POST /v1/predict`
/// on 0.0.0.0:8090 with 4 workers.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let shared = web::Data::new(ModelFile::new(
        "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/config.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".to_string(),
    ));

    // Every worker gets a handle (an `Arc` clone) to the same `ModelFile`.
    let server = HttpServer::new(move || {
        App::new()
            .app_data(shared.clone())
            .service(predicet_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?;

    server.run().await
}

fn get_weights(model_path: String) -> anyhow::Result<Vec<u8>, anyhow::Error> {
    Ok(std::fs::read(model_path)?)
}
guillaume-be commented 11 months ago

Hello @wolf-li ,

Do you compile the code in release mode with all optimizations?

linkedlist771 commented 5 months ago

Each call to the `genertation` function creates a new model (loading it from disk and initializing it); I think that is the cause of the slowdown.