guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

Question: How could I run custom Marian Model using rust-bert? #412

Closed: wolf-li closed this issue 1 year ago

wolf-li commented 1 year ago

Hey @guillaume-be, awesome job on this.

I'd like to ask how to load my fine-tuned model. The model file has already been converted to the .ot format, but I could not find a way in the source code to load local model weights.

wolf-li commented 1 year ago

I wrote a translation pipeline that uses a local Marian model:

use std::sync::{Arc, RwLock};

use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{BufferResource, LocalResource};
use tch::Device;

fn main() -> anyhow::Result<()> {
    let input_context_1 = ["你好"];
    // let input_context_2 = "世界";

    // Read the converted fine-tuned weights (.ot) into memory and pass them
    // to the pipeline as an in-memory buffer resource.
    let weights = Arc::new(RwLock::new(get_weights()?));
    let model_resource = ModelResource::Torch(Box::new(BufferResource { data: weights }));

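    // The configuration, vocabulary and SentencePiece model are read from local files.
    let config_resource = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/config.json".into(),
    };
    let vocab_resource = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".into(),
    };
    // For Marian, the "merges" argument is given the SentencePiece model (source.spm).
    let merges_resource = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".into(),
    };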

    let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
    let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;

    let translation_config = TranslationConfig::new(
        ModelType::Marian,
        model_resource,
        config_resource,
        vocab_resource,
        Some(merges_resource),
        source_languages,
        target_languages,
        Device::Cpu,
    );
    let model = TranslationModel::new(translation_config)?;

    let output = model.translate(&input_context_1, None, None)?;

    for sentence in output {
        println!("{sentence}");
    }

    Ok(())
}

// Reads the converted model weights from disk into memory.
fn get_weights() -> anyhow::Result<Vec<u8>> {
    Ok(std::fs::read("/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot")?)
}
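
The snippet above reads four files from the same directory: rust_model.ot, config.json, vocab.json and source.spm. When pointing it at a different fine-tuned model, it may help to check that all four are present before building the pipeline. Below is a minimal sketch using only the standard library. The helper name missing_model_files is hypothetical and not part of rust-bert, and the directory is simply the one used above.

```rust
use std::path::{Path, PathBuf};

/// File names the translation snippet reads from the model directory.
const REQUIRED_FILES: [&str; 4] = ["rust_model.ot", "config.json", "vocab.json", "source.spm"];

/// Hypothetical helper (not a rust-bert API): returns the full paths of
/// the required files that do not exist as regular files in `model_dir`.
fn missing_model_files(model_dir: &Path) -> Vec<PathBuf> {
    REQUIRED_FILES
        .iter()
        .map(|name| model_dir.join(name))
        .filter(|path| !path.is_file())
        .collect()
}

fn main() {
    // Assumption: the same directory as in the snippet above.
    let model_dir = Path::new("/root/.cache/.rustbert/opus-mt-zh-en");

    let missing = missing_model_files(model_dir);
    if missing.is_empty() {
        println!("all model files found in {}", model_dir.display());
    } else {
        for path in &missing {
            eprintln!("missing: {}", path.display());
        }
    }
}
```

Calling this at the start of main, before TranslationModel::new, would report a missing or misnamed file by name before any model loading is attempted.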