guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

Question: Configuring ZeroShotClassificationModel with DeBERTaV2 - Documentation #433

Open Philipp-Sc opened 10 months ago

Philipp-Sc commented 10 months ago

Hi @jondot @guillaume-be

I am working on an llm-fraud-detection library using rust-bert and llama.cpp.

While updating the repo, I would like to test DeBERTaV2 for my zero-shot classification task.

Currently I am using BERT, which is very simple to set up thanks to the provided Default implementation:

ZeroShotClassificationModel::new(Default::default());

See

Do I need to download the model myself, convert it to the Rust format (rust_model.ot) and provide the paths like this? See #406.

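use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::Path;
use tch::Device;

// Builds a zero-shot config from a local directory holding converted weights.
fn generation_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = Path::new(base_path);
    let model_path = base.join("rust_model.ot"); // weights converted to the libtorch format
    let config_path = base.join("config.json");
    // NOTE: DeBERTa V2/V3 checkpoints ship a SentencePiece spm.model rather than
    // vocab.json + merges.txt (see the reply below), so these two paths may need changing.
    let vocab_path = base.join("vocab.json");
    let merges_path = base.join("merges.txt");

    ZeroShotClassificationConfig {
        device: Device::cuda_if_available(),
        ..ZeroShotClassificationConfig::new(
            ModelType::DebertaV2, // the enum variant is DebertaV2, not DeBERTaV2
            ModelResource::Torch(Box::new(LocalResource::from(model_path))),
            LocalResource::from(config_path),
            LocalResource::from(vocab_path),
            Some(LocalResource::from(merges_path)),
            false, // lower_case
            None,  // strip_accents
            None,  // add_prefix_space
        )
    }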
}
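Which tokenizer files a local model directory must contain depends on the tokenizer family, and this is easy to get wrong when pointing a config at local paths. The following is a minimal std-only sketch that lists the files a directory is expected to hold and reports any that are missing. TokenizerFiles, tokenizer_file_names and missing_files are hypothetical helpers, not part of rust-bert, and the file names assume the usual Hugging Face checkpoint layout: vocab.txt for WordPiece models such as BERT, vocab.json plus merges.txt for byte-level BPE models such as RoBERTa or BART, and spm.model for SentencePiece models such as DeBERTa V2/V3, alongside config.json and converted rust_model.ot weights.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical tokenizer families, for illustration only (not a rust-bert type).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum TokenizerFiles {
    /// WordPiece models such as BERT: a single vocab.txt
    WordPiece,
    /// Byte-level BPE models such as RoBERTa or BART: vocab.json + merges.txt
    BytePairEncoding,
    /// SentencePiece models such as DeBERTa V2/V3: a single spm.model
    SentencePiece,
}

/// Expected tokenizer file names for a family: (vocab file, optional merges file).
fn tokenizer_file_names(kind: TokenizerFiles) -> (&'static str, Option<&'static str>) {
    match kind {
        TokenizerFiles::WordPiece => ("vocab.txt", None),
        TokenizerFiles::BytePairEncoding => ("vocab.json", Some("merges.txt")),
        TokenizerFiles::SentencePiece => ("spm.model", None),
    }
}

/// Returns the expected files (weights, config, tokenizer) that do not exist under `base`.
fn missing_files(base: &Path, weights: &str, kind: TokenizerFiles) -> Vec<PathBuf> {
    let (vocab, merges) = tokenizer_file_names(kind);
    let mut expected = vec![base.join(weights), base.join("config.json"), base.join(vocab)];
    if let Some(merges) = merges {
        expected.push(base.join(merges));
    }
    expected.into_iter().filter(|p| !p.exists()).collect()
}

fn main() {
    // Hypothetical demo directory holding empty placeholder files in a DeBERTa V2 layout.
    let base = std::env::temp_dir().join("deberta_v2_layout_sketch");
    let _ = std::fs::remove_dir_all(&base);
    std::fs::create_dir_all(&base).expect("could not create the demo directory");
    for name in ["rust_model.ot", "config.json", "spm.model"] {
        std::fs::write(base.join(name), b"").expect("could not create a placeholder file");
    }

    // The SentencePiece layout is complete; the BPE layout reports vocab.json and merges.txt.
    for kind in [TokenizerFiles::SentencePiece, TokenizerFiles::BytePairEncoding] {
        let missing = missing_files(&base, "rust_model.ot", kind);
        println!("{:?}: missing {:?}", kind, missing);
    }
}
```

Running such a check before building the config turns a tokenizer loading error into an explicit list of missing files.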

Can you point me in the right direction on how to set up the ZeroShotClassificationModel for DeBERTaV2?

This might be a good time to add an example for this; it would likely help with #425 as well. (I would be happy to propose a PR once I get the hang of it.)

Thanks in advance.

Best regards, Philipp-Sc

mikkel1156 commented 8 months ago

In my case the problem was the wrong vocab file: it turned out that the spm.model file is the vocab file for the model I'm using.

However, this was with the ONNX backend (which requires the onnx feature to be enabled). The model I used is https://huggingface.co/Xenova/nli-deberta-v3-large

use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::PathBuf;

// Requires the "onnx" feature of rust-bert.
fn model(base_path: &str) {
    let model_path = PathBuf::from(base_path.to_owned() + "onnx/model.onnx");
    let config_path = PathBuf::from(base_path.to_owned() + "config.json");
    // For this DeBERTa V3 checkpoint the SentencePiece model is the vocab file.
    let vocab_path = PathBuf::from(base_path.to_owned() + "spm.model");

    let classification_model = ZeroShotClassificationModel::new(ZeroShotClassificationConfig::new(
        ModelType::DebertaV2,
        ModelResource::ONNX(ONNXModelResources {
            // Encoder-only model: only the encoder resource is set.
            encoder_resource: Some(Box::new(LocalResource::from(model_path))),
            ..Default::default()
        }),
        LocalResource::from(config_path),
        LocalResource::from(vocab_path),
        None,  // no merges file for a SentencePiece tokenizer
        false, // lower_case
        None,  // strip_accents
        None,  // add_prefix_space
    )).expect("could not create zero_shot_classification model");
    let input = ["Who are you voting for in 2020?", "The prime minister has announced a stimulus package which was widely criticized by the opposition."];
    let labels = &["politics", "public health", "economics", "sports"];
    // None keeps the default hypothesis template; 128 is the maximum sequence length.
    let output = classification_model.predict_multilabel(
        &input,
        labels,
        None,
        128
    );
    println!("{:?}", output);
}
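For readers wondering what the scores returned by predict_multilabel mean: NLI-based zero-shot classification typically pairs the input (the premise) with one hypothesis per label, built from a template such as Hugging Face's default "This example is {}.", and reads the NLI head's contradiction and entailment logits. The following std-only sketch shows the two usual scoring modes, following the convention of Hugging Face's zero-shot pipeline; whether rust-bert normalizes in exactly this way is an assumption, the logits are made up rather than real model output, and NliLogits, multilabel_scores and single_label_scores are hypothetical names. Which output index is "entailment" comes from the id2label mapping in the model's config.json.

```rust
/// Numerically stable softmax over a slice of logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Hypothetical NLI logits for one (premise, hypothesis) pair.
/// The neutral logit is ignored by both scoring modes below, so it is omitted.
struct NliLogits {
    contradiction: f64,
    entailment: f64,
}

/// Multi-label: each label is scored independently from its own
/// contradiction-vs-entailment pair, so the scores need not sum to 1.
fn multilabel_scores(per_label: &[NliLogits]) -> Vec<f64> {
    per_label
        .iter()
        .map(|l| softmax(&[l.contradiction, l.entailment])[1])
        .collect()
}

/// Single-label: entailment logits are compared across labels, so the scores sum to 1.
fn single_label_scores(per_label: &[NliLogits]) -> Vec<f64> {
    let entailment: Vec<f64> = per_label.iter().map(|l| l.entailment).collect();
    softmax(&entailment)
}

fn main() {
    let labels = ["politics", "public health", "economics", "sports"];
    // Made-up logits for illustration only, not output of any real model.
    let logits = [
        NliLogits { contradiction: -2.0, entailment: 3.0 },
        NliLogits { contradiction: 2.5, entailment: -1.5 },
        NliLogits { contradiction: -0.5, entailment: 1.0 },
        NliLogits { contradiction: 3.0, entailment: -3.0 },
    ];

    let multi = multilabel_scores(&logits);
    let single = single_label_scores(&logits);
    for ((label, m), s) in labels.iter().zip(&multi).zip(&single) {
        println!("{label:>14}  multilabel {m:.3}  single-label {s:.3}");
    }
}
```

The design difference is what gets normalized: multi-label scoring lets several labels (for example "politics" and "economics") score high at once, while single-label scoring forces the labels to compete for one unit of probability.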