Open wolf-li opened 1 year ago
The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.
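A quick way to confirm on the Rust side that CUDA is actually visible to the linked libtorch is to query it through tch before building the pipeline. A minimal sketch (the index 3 only mirrors the Device::Cuda(3) used in the code below; adjust it to your setup):

use tch::Device;

fn main() {
    // Report what the libtorch build linked by tch can see.
    println!("CUDA available: {}", tch::Cuda::is_available());
    println!("cuDNN available: {}", tch::Cuda::cudnn_is_available());
    println!("CUDA device count: {}", tch::Cuda::device_count());

    // Fall back to CPU when no GPU is present, or pin an explicit index.
    let device = if tch::Cuda::is_available() {
        Device::Cuda(3)
    } else {
        Device::Cpu
    };
    println!("selected device: {:?}", device);
}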
Thanks for your reply. When the Marian pipeline runs on the GPU (Device::Cuda(3) is specified, and nvidia-smi shows the GPU occupied while the Rust program runs), it is still slower than the Python call to the Hugging Face transformers library running in a Docker environment without a GPU.
Python code (CPU: 0.3 s, averaged over 100 requests)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MarianTokenizer, MarianMTModel

# Load the Marian model and tokenizer
model_name = "Helsinki-NLP/opus-mt-zh-en"  # Replace with your desired model
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)

app = FastAPI()


class InputData(BaseModel):
    text: str


@app.post("/v1/predict", response_model=dict)
async def predict(input_data: InputData):
    # Translate the input text
    input_text = input_data.text
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    translation_ids = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generation_text = tokenizer.decode(translation_ids[0], skip_special_tokens=True)
    return {"generation_text": generation_text}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
Rust code (CPU: 0.71 s, GPU: 0.38 s, averaged over 100 requests)
use actix_web::{
    error,
    http::{header::ContentType, StatusCode},
    post, web, App, HttpResponse, HttpServer, Responder, Result,
};
use derive_more::{Display, Error};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{BufferResource, LocalResource};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use tch::Device;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[derive(Serialize)]
struct Output {
    generation_text: String,
}

// Holds the raw model weights in memory plus the paths to the auxiliary files.
struct ModelFile {
    config_resource: LocalResource,
    weights: Arc<RwLock<Vec<u8>>>,
    vocab_resource: LocalResource,
    merges_resource: LocalResource,
}

impl ModelFile {
    fn new(model_path: String, config_path: String, vocab_path: String, merges_path: String) -> Self {
        let weights = Arc::new(RwLock::new(get_weights(model_path).unwrap()));
        let config_resource = LocalResource { local_path: config_path.into() };
        let vocab_resource = LocalResource { local_path: vocab_path.into() };
        let merges_resource = LocalResource { local_path: merges_path.into() };
        Self {
            weights,
            config_resource,
            vocab_resource,
            merges_resource,
        }
    }

    // Note: a new TranslationModel is constructed on every call, so each request
    // pays the full model initialization cost in addition to the translation itself.
    fn generation(&self, input_context: &str) -> Result<impl Responder, MyError> {
        let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
        let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;
        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            ModelResource::Torch(Box::new(BufferResource {
                data: Arc::clone(&self.weights),
            })),
            self.config_resource.clone(),
            self.vocab_resource.clone(),
            Some(self.merges_resource.clone()),
            source_languages,
            target_languages,
            // Device::Cpu,
            Device::Cuda(3),
        );
        let model = TranslationModel::new(translation_config).map_err(|_| MyError::ModelLoadError)?;
        let output = model
            .translate(&[input_context.to_string()], None, None)
            .map_err(|_| MyError::TranslateError)?;
        match output.first() {
            Some(first_element) => Ok(web::Json(Output {
                generation_text: first_element.to_string(),
            })),
            None => Err(MyError::TranslateError),
        }
    }
}

#[derive(Debug, Display, Error)]
enum MyError {
    #[display(fmt = "translation model load error")]
    ModelLoadError,
    #[display(fmt = "translate error")]
    TranslateError,
}

impl error::ResponseError for MyError {
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .insert_header(ContentType::html())
            .body(self.to_string())
    }

    fn status_code(&self) -> StatusCode {
        match *self {
            MyError::ModelLoadError => StatusCode::INTERNAL_SERVER_ERROR,
            MyError::TranslateError => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

#[post("/v1/predict")]
async fn predict_post(
    info: web::Json<Input>,
    appdata: web::Data<ModelFile>,
) -> Result<impl Responder, MyError> {
    appdata.generation(&info.text)
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let appdata = ModelFile::new(
        "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/config.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".to_string(),
    );
    let appdata = web::Data::new(appdata);
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::clone(&appdata))
            .service(predict_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?
    .run()
    .await
}

fn get_weights(model_path: String) -> anyhow::Result<Vec<u8>> {
    Ok(std::fs::read(model_path)?)
}
Hello @wolf-li,
Do you compile the code in release mode with all optimizations?
Each time you call the generation function it creates a new model (loaded from your disk and initialized); I think that is the cause of the slowdown.
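Following up on that: one way to avoid the per-request reload is to build the TranslationModel once at startup and share it across handlers. A rough sketch of the pattern, assuming TranslationModel is Send so it can live behind a Mutex inside web::Data (if it is not, the usual alternative is a dedicated worker thread that receives requests over a channel):

use std::sync::Mutex;
use actix_web::web;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};

// Build the model a single time and wrap it for shared, serialized access.
fn shared_model(config: TranslationConfig) -> anyhow::Result<web::Data<Mutex<TranslationModel>>> {
    let model = TranslationModel::new(config)?;
    Ok(web::Data::new(Mutex::new(model)))
}

// In a handler, lock the existing model instead of constructing a new one per call.
fn translate_once(model: &web::Data<Mutex<TranslationModel>>, text: &str) -> anyhow::Result<String> {
    let guard = model.lock().expect("model mutex poisoned");
    let output = guard.translate(&[text], None, None)?;
    Ok(output.into_iter().next().unwrap_or_default())
}

Note that a single Mutex serializes inference across the 4 actix workers; you would either accept that, or keep one model per worker at the cost of extra memory.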
Even once the model file is loaded, using the pipeline on CPU is still slower than calling the generate function of the Python Hugging Face transformers library.
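One thing worth ruling out for the CPU comparison is the intra-op thread count: the PyTorch build used from Python and the libtorch build linked by tch may default to different numbers of threads, which alone can explain a gap of this size. A small sketch, assuming the tch version in use exposes set_num_threads/get_num_threads (the value 8 is only an example; match whatever torch.get_num_threads() reports on the Python side):

fn main() {
    // Pin the intra-op thread count so both frameworks use the same CPU parallelism.
    tch::set_num_threads(8);
    println!("intra-op threads: {}", tch::get_num_threads());
}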