I have converted the model into both formats that rust-bert supports: the .ot weights format and the newly supported ONNX format. Even so, I have not managed to get it working. I tried to replicate a process I previously got working with BART, but with DeBERTa-v3 it fails. Could anyone provide a short guide or some pointers on running zero-shot classification with DeBERTa-v3? For reference, this is the Python version I am trying to port:
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
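On the Rust side, the direction I have been attempting for DeBERTa-v3 looks roughly like this. To be clear, this is a sketch of my attempt rather than working code: I am assuming that ModelType::DebertaV2 is the closest match for a v3 checkpoint (v3 reuses the v2 architecture), that the zero-shot pipeline accepts it, and that the tokenizer resource should be the SentencePiece model (spm.model) shipped with the checkpoint rather than BART's vocab.json/merges.txt pair. I am also matching the resource field types used by my BART config below; as far as I can tell, the rust-bert release that added ONNX support wraps model_resource in an enum instead, so the exact types depend on the version:

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::PathBuf;
use tch::Device;

// Sketch only: assumes ModelType::DebertaV2 applies to DeBERTa-v3 and that
// the zero-shot pipeline accepts it. DeBERTa-v3 ships a SentencePiece
// tokenizer, so vocab_resource points at spm.model and there is no merges.txt.
fn deberta_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = PathBuf::from(base_path);
    ZeroShotClassificationConfig {
        model_type: ModelType::DebertaV2,
        model_resource: Box::new(LocalResource::from(base.join("rust_model.ot"))),
        config_resource: Box::new(LocalResource::from(base.join("config.json"))),
        vocab_resource: Box::new(LocalResource::from(base.join("spm.model"))),
        merges_resource: None,
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}

If the zero-shot pipeline simply does not support ModelType::DebertaV2 yet, knowing that would already help, since it would explain the failures.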
Below is my working implementation using Facebook's BART-large. It exposes two endpoints: a REST endpoint and a ZeroMQ endpoint for streaming, with the model served behind the ZeroMQ socket and inference run in batches.
use actix_web::{web, App, HttpResponse, HttpServer, Responder};
use anyhow::Result;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::LocalResource;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::path::PathBuf;
use std::thread;
use tch::Device;
use zmq::{Context, SocketType};
// Request payload: the sentences to classify and the candidate labels.
#[derive(Deserialize, Serialize)]
struct MyData {
    sentence: Vec<String>,
    categories: Vec<String>,
}
/// Zero-shot pipeline configuration for a locally converted BART-large
/// checkpoint. BART uses a BPE tokenizer, hence vocab.json plus merges.txt.
fn generation_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = PathBuf::from(base_path);
    ZeroShotClassificationConfig {
        model_type: ModelType::Bart,
        model_resource: Box::new(LocalResource::from(base.join("rust_model.ot"))),
        config_resource: Box::new(LocalResource::from(base.join("config.json"))),
        vocab_resource: Box::new(LocalResource::from(base.join("vocab.json"))),
        merges_resource: Some(Box::new(LocalResource::from(base.join("merges.txt")))),
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}
fn error_response() -> Result<String, serde_json::Error> {
    let error_message = "Malformed JSON.";
    let example_json = r#"{
        "sentence": ["Who are you going to vote for in 2020?"],
        "categories": ["politics", "economy", "sports"]
    }"#;
    let error_response = json!({
        "error": error_message,
        "example": serde_json::from_str::<Value>(example_json)?
    });
    serde_json::to_string(&error_response)
}
/// Runs zero-shot prediction in batches and keeps the top-scoring label for
/// each input sentence.
fn predict(
    model: &ZeroShotClassificationModel,
    data: &[String],
    candidate_labels: &[&str],
    batch_size: usize,
) -> Vec<String> {
    let mut predictions = Vec::new();
    for batch_start in (0..data.len()).step_by(batch_size) {
        let batch_end = std::cmp::min(batch_start + batch_size, data.len());
        let batch = &data[batch_start..batch_end];
        let output = model
            .predict_multilabel(
                &batch.iter().map(AsRef::as_ref).collect::<Vec<&str>>(),
                candidate_labels,
                Some(Box::new(|label: &str| {
                    format!("This example is about {label}.")
                })),
                // The fourth argument is max_length (token truncation), not
                // the batch size; batching is handled by the loop above.
                128,
            )
            .unwrap();
        for item in output {
            let prediction = item
                .iter()
                .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                .unwrap()
                .text
                .clone();
            predictions.push(prediction);
        }
    }
    predictions
}
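// Note: the zero-shot pipeline also has a predict method that returns a single
// Label per input, which maps more directly onto multi_label=False in the
// Python snippet above; taking the argmax over predict_multilabel's scores, as
// done here, is my way of approximating the same behaviour.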
/// Hosts the model behind a ZeroMQ REP socket so a single model instance
/// serves every request, whether it arrives over HTTP or directly over ZeroMQ.
fn run_zeromq_server() -> Result<()> {
    let ctx = Context::new();
    let socket = ctx.socket(SocketType::REP)?;
    socket.bind("tcp://*:6044")?;
    let base_path = "/root/rustmodels/zeroshot/";
    let config = generation_config(base_path);
    let sequence_classification_model = ZeroShotClassificationModel::new(config)?;
    loop {
        let message = socket.recv_string(0)?.unwrap();
        let data_result: Result<HashMap<String, Vec<String>>, serde_json::Error> =
            serde_json::from_str(&message);
        match data_result {
            Ok(data) => {
                if let (Some(sentence), Some(candidate_labels)) =
                    (data.get("sentence"), data.get("categories"))
                {
                    let candidate_labels: Vec<&str> =
                        candidate_labels.iter().map(AsRef::as_ref).collect();
                    let predictions = predict(
                        &sequence_classification_model,
                        sentence,
                        &candidate_labels,
                        24,
                    );
                    let result = serde_json::to_string(&predictions)?;
                    socket.send(&result, 0)?;
                } else {
                    let error_json = error_response()?;
                    socket.send(&error_json, 0)?;
                }
            }
            Err(_) => {
                let error_json = error_response()?;
                socket.send(&error_json, 0)?;
            }
        }
    }
}
/// HTTP handler: forwards the JSON body to the ZeroMQ model server and relays
/// the reply. A REQ socket is opened per request, which is simple but means
/// each HTTP call pays the connection overhead.
async fn handle_predict(data: web::Json<MyData>) -> impl Responder {
    let context = zmq::Context::new();
    let requester = context.socket(zmq::REQ).unwrap();
    requester.connect("tcp://127.0.0.1:6044").unwrap();
    let json_data = serde_json::to_string(&*data).unwrap();
    if requester.send(json_data.as_bytes(), 0).is_err() {
        return HttpResponse::InternalServerError()
            .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")));
    }
    let reply = match requester.recv_msg(0) {
        Ok(reply) => reply,
        Err(_) => {
            return HttpResponse::InternalServerError()
                .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")));
        }
    };
    let reply_str = std::str::from_utf8(&reply).unwrap().to_owned();
    HttpResponse::Ok().body(reply_str)
}
async fn run_http_server() -> std::io::Result<()> {
    HttpServer::new(|| App::new().route("/predict", web::post().to(handle_predict)))
        .bind("0.0.0.0:8081")?
        .run()
        .await
}
fn main() {
    // The model server and the HTTP server each get their own OS thread; the
    // actix runtime is created manually so it can run off the main thread.
    let handle_zeromq = thread::spawn(|| {
        run_zeromq_server().expect("ZeroMQ server failed.");
    });
    let handle_http = thread::spawn(|| {
        let sys = actix_web::rt::System::new();
        sys.block_on(run_http_server()).expect("HTTP server failed.");
    });
    handle_zeromq.join().unwrap();
    handle_http.join().unwrap();
}
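For completeness, this is how I exercise the service while testing: a minimal ZeroMQ client that assumes the server above is already running on port 6044 (the REST endpoint takes the same JSON body via POST /predict on port 8081):

use serde_json::json;

fn main() {
    // Minimal test client for the ZeroMQ endpoint above.
    let ctx = zmq::Context::new();
    let socket = ctx.socket(zmq::REQ).unwrap();
    socket.connect("tcp://127.0.0.1:6044").unwrap();
    let request = json!({
        "sentence": ["Who are you going to vote for in 2020?"],
        "categories": ["politics", "economy", "sports"]
    });
    socket.send(request.to_string().as_bytes(), 0).unwrap();
    // The reply is a JSON array with one predicted label per input sentence.
    let reply = socket.recv_string(0).unwrap().unwrap();
    println!("{reply}");
}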