Closed wolf-li closed 1 year ago
I wrote a translation pipeline task that uses a local Marian model.
extern crate anyhow;
use std::sync::{Arc, RwLock};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider, Resource, LocalResource};
use tch::Device;
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// Translates a Chinese sentence to English with a Marian model stored on local disk.
///
/// The model weights are read fully into memory and handed to the pipeline as a
/// `BufferResource`; the configuration, vocabulary and SentencePiece files are
/// referenced by path through `LocalResource`. Each translated sentence is printed
/// on its own line.
///
/// # Errors
/// Propagates any error from reading the weights file, building the
/// `TranslationModel`, or running the translation.
fn main() -> anyhow::Result<()> {
    let inputs = ["你好"];

    // The weights live in an in-memory buffer instead of being resolved from a path.
    let weights_buffer = BufferResource {
        data: Arc::new(RwLock::new(get_weights()?)),
    };

    let config_file = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/config.json".into(),
    };
    let vocab_file = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".into(),
    };
    // The `.spm` file is passed in the optional "merges" slot of the configuration.
    let spm_file = LocalResource {
        local_path: "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".into(),
    };

    let config = TranslationConfig::new(
        ModelType::Marian,
        ModelResource::Torch(Box::new(weights_buffer)),
        config_file,
        vocab_file,
        Some(spm_file),
        MarianSourceLanguages::CHINESE2ENGLISH,
        MarianTargetLanguages::CHINESE2ENGLISH,
        Device::Cpu,
    );

    let translator = TranslationModel::new(config)?;
    let translations = translator.translate(&inputs, None, None)?;
    for line in translations {
        println!("{line}");
    }

    Ok(())
}
/// Reads the converted model weights file (`rust_model.ot`) from the local
/// rust-bert cache directory into a byte vector.
///
/// # Errors
/// Returns the I/O error from `std::fs::read`, converted into `anyhow::Error`,
/// if the file cannot be read.
fn get_weights() -> anyhow::Result<Vec<u8>, anyhow::Error> {
    let bytes = std::fs::read("/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot")?;
    Ok(bytes)
}
Hey @guillaume-be, awesome job on this.
I would like to ask how to load my fine-tuned model. The model file has already been converted to the `.ot` format, but I could not find a way in the source code to load local model weights.