guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

any help with DeBERTa #406

Open traderpedroso opened 1 year ago

traderpedroso commented 1 year ago

I have converted the model into both formats supported by rust-bert: the `.ot` (TorchScript) weights and the newly supported ONNX format. Despite this, I have not managed to get the implementation working. I tried to replicate a process I had previously gotten working with BART, but it failed. Could anyone provide a preliminary guide or some assistance on how to run zero-shot classification with DeBERTa-v3?

from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
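
For reference, here is a minimal sketch of the configuration I would expect on the rust-bert side. This is an assumption, not verified code: it presumes that ModelType::DebertaV2 (DeBERTa-v3 reuses the v2 architecture with a SentencePiece tokenizer) is accepted by the zero-shot pipeline, which may be exactly what is missing, and that the conversion produced the usual rust_model.ot, config.json, and spm.model files:

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::PathBuf;
use tch::Device;

// Hypothetical DeBERTa-v3 config: ModelType::DebertaV2 and the SentencePiece
// vocab file are assumptions; the zero-shot pipeline may not accept them yet.
fn deberta_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = PathBuf::from(base_path);
    ZeroShotClassificationConfig {
        model_type: ModelType::DebertaV2,
        model_resource: Box::new(LocalResource::from(base.join("rust_model.ot"))),
        config_resource: Box::new(LocalResource::from(base.join("config.json"))),
        vocab_resource: Box::new(LocalResource::from(base.join("spm.model"))),
        merges_resource: None, // SentencePiece tokenizers have no merges file
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}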

Here follows my working implementation using Facebook's BART Large. It exposes one REST endpoint and one ZeroMQ endpoint for streaming, processing requests in batches.

use actix_web::{web, App, HttpResponse, HttpServer, Responder};
use anyhow::Result;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::LocalResource;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::path::PathBuf;
use std::thread;
use tch::Device;
use zmq::{Context, SocketType};

#[derive(Deserialize, Serialize)]
struct MyData {
    sentence: Vec<String>,
    categories: Vec<String>,
}

// Builds the zero-shot classification config for a locally converted BART model.
fn classification_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = PathBuf::from(base_path);
    let model_path = base.join("rust_model.ot");
    let config_path = base.join("config.json");
    let vocab_path = base.join("vocab.json");
    let merges_path = base.join("merges.txt");

    ZeroShotClassificationConfig {
        model_type: ModelType::Bart,
        model_resource: Box::new(LocalResource::from(model_path)),
        config_resource: Box::new(LocalResource::from(config_path)),
        vocab_resource: Box::new(LocalResource::from(vocab_path)),
        merges_resource: Some(Box::new(LocalResource::from(merges_path))),
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}

fn error_response() -> Result<String, serde_json::Error> {
    let error_message = "Malformed JSON.";
    let example_json = r#"{
        "sentence": ["Who are you going to vote for in 2020?"],
        "categories": ["politics", "economy", "sports"]
    }"#;

    let error_response = json!({
        "error": error_message,
        "example": serde_json::from_str::<Value>(example_json)?
    });

    serde_json::to_string(&error_response)
}

// Runs multi-label prediction in chunks of `batch_size` and keeps the
// highest-scoring label for each input sentence.
fn predict(
    model: &ZeroShotClassificationModel,
    data: &[String],
    candidate_labels: &[&str],
    batch_size: usize,
) -> Vec<String> {
    let mut predictions = Vec::new();

    for batch_start in (0..data.len()).step_by(batch_size) {
        let batch_end = std::cmp::min(batch_start + batch_size, data.len());
        let batch = &data[batch_start..batch_end];

        let output = model
            .predict_multilabel(
                &batch.iter().map(AsRef::as_ref).collect::<Vec<&str>>(),
                candidate_labels,
                Some(Box::new(|label: &str| {
                    format!("This example is about {label}.")
                })),
                128, // max sequence length in tokens; this argument is max_length, not a batch size
            )
            .unwrap();

        for item in output {
            let prediction = item
                .iter()
                .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                .unwrap()
                .text
                .clone();
            predictions.push(prediction);
        }
    }

    predictions
}

// ZeroMQ REP worker: loads the model once, then answers each JSON request
// with the predicted label for every input sentence.
fn run_zeromq_server() -> Result<()> {
    let ctx = Context::new();
    let socket = ctx.socket(SocketType::REP)?;
    socket.bind("tcp://*:6044")?;

    let base_path = "/root/rustmodels/zeroshot/";

    let config = classification_config(base_path);

    let sequence_classification_model = ZeroShotClassificationModel::new(config)?;

    loop {
        let message = socket
            .recv_string(0)?
            .map_err(|_| anyhow::anyhow!("received a non-UTF-8 message"))?;
        let data_result: Result<HashMap<String, Vec<String>>, serde_json::Error> =
            serde_json::from_str(&message);

        match data_result {
            Ok(data) => {
                if let (Some(sentence), Some(candidate_labels)) = (data.get("sentence"), data.get("categories")) {
                    let candidate_labels: Vec<&str> = candidate_labels.iter().map(AsRef::as_ref).collect();

                    let predictions = predict(
                        &sequence_classification_model,
                        sentence,
                        &candidate_labels,
                        24,
                    );

                    let result = serde_json::to_string(&predictions)?;
                    socket.send(&result, 0)?;
                } else {
                    let error_json = error_response()?;
                    socket.send(&error_json, 0)?;
                }
            }
            Err(_) => {
                let error_json = error_response()?;
                socket.send(&error_json, 0)?;
            }
        }
    }
}

async fn handle_predict(data: web::Json<MyData>) -> impl Responder {
    // Forward the HTTP request to the ZeroMQ worker and relay its reply.
    let context = zmq::Context::new();
    let requester = context.socket(zmq::REQ).unwrap();
    requester.connect("tcp://127.0.0.1:6044").unwrap();

    let json_data = serde_json::to_string(&*data).unwrap();
    if requester.send(json_data.as_bytes(), 0).is_err() {
        return HttpResponse::InternalServerError()
            .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")));
    }

    let reply = match requester.recv_msg(0) {
        Ok(reply) => reply,
        Err(_) => {
            return HttpResponse::InternalServerError()
                .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")))
        }
    };

    let reply_str = std::str::from_utf8(&reply).unwrap_or_default().to_owned();

    HttpResponse::Ok().body(reply_str)
}

async fn run_http_server() -> std::io::Result<()> {
    HttpServer::new(|| App::new().route("/predict", web::post().to(handle_predict)))
        .bind("0.0.0.0:8081")?
        .run()
        .await
}

fn main() {
    // Run the ZeroMQ worker and the HTTP front end on separate threads.
    let handle_zeromq = thread::spawn(|| {
        run_zeromq_server().expect("ZeroMQ server failed.");
    });

    let handle_http = thread::spawn(|| {
        let sys = actix_web::rt::System::new();
        sys.block_on(run_http_server()).expect("HTTP server failed.");
    });

    handle_zeromq.join().unwrap();
    handle_http.join().unwrap();
}
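
To test the ZeroMQ endpoint directly, a minimal client sketch along these lines should work (the address and payload shape match the server above; treat it as illustrative rather than tested):

use zmq::{Context, SocketType};

fn main() -> zmq::Result<()> {
    // Connect a REQ socket to the REP server started by run_zeromq_server().
    let ctx = Context::new();
    let socket = ctx.socket(SocketType::REQ)?;
    socket.connect("tcp://127.0.0.1:6044")?;

    // The payload mirrors the MyData struct expected by the server.
    let request = r#"{
        "sentence": ["Who are you going to vote for in 2020?"],
        "categories": ["politics", "economy", "sports"]
    }"#;
    socket.send(request, 0)?;

    // The reply is a JSON array with one predicted label per input sentence.
    let reply = socket.recv_string(0)?.expect("reply was not valid UTF-8");
    println!("{reply}");
    Ok(())
}
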
Philipp-Sc commented 10 months ago

@traderpedroso Have you made any progress?