jkawamoto / ctranslate2-rs

Rust bindings for OpenNMT/CTranslate2
https://docs.rs/ct2rs
MIT License

Add support for MarianMT Model #38

Closed: solaoi closed this issue 4 months ago

solaoi commented 5 months ago

Thank you for your great work! I would like to use this with the MarianMT Model as described here: https://opennmt.net/CTranslate2/guides/transformers.html#marianmt

I would appreciate your support on this.

jkawamoto commented 5 months ago

MarianMT models are compatible with the current implementation using the sentencepiece tokenizer. Below is a sample code:

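use sentencepiece::SentencePieceProcessor;

use ct2rs::config::Config;
use ct2rs::translator::Translator;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the converted CTranslate2 model and the SentencePiece models
    // shipped with the original MarianMT checkpoint.
    let t = Translator::new("./data/opus-mt-en-jap", Config::default())?;
    let encoder = SentencePieceProcessor::open("./data/opus-mt-en-jap/source.spm")?;
    let decoder = SentencePieceProcessor::open("./data/opus-mt-en-jap/target.spm")?;

    // Tokenize the input into SentencePiece pieces (strings, not IDs).
    let source: Vec<String> = encoder.encode(
        "Hello world! This library provides Rust bindings for CTranslate2.",
    )?.iter().map(|v| v.piece.to_string()).collect();

    // Translate a batch of one sentence; `vec![vec![""]]` is passed as its target prefix.
    let res = t.translate_batch(&*vec![source], &*vec![vec![""]], &Default::default())?;
    for r in res {
        // Print the best hypothesis, decoded back into plain text.
        if let Some(h) = r.hypotheses.get(0) {
            println!("{:?}", decoder.decode_pieces(h)?);
        }
    }
    Ok(())
}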
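
For readers new to SentencePiece: the encoder splits text into pieces in which "▁" (U+2581) marks the start of a word, and decoding essentially concatenates the pieces and turns that marker back into spaces. Below is a simplified, std-only sketch of that idea; detokenize is a hypothetical name, and the real sentencepiece crate's decode_pieces handles more cases (for example byte fallback), so it remains the right call in practice.

```rust
/// Simplified sketch of SentencePiece detokenization (hypothetical helper):
/// concatenate the pieces and turn the U+2581 word marker back into spaces.
fn detokenize<S: AsRef<str>>(pieces: &[S]) -> String {
    let mut joined = String::new();
    for p in pieces {
        joined.push_str(p.as_ref());
    }
    // '\u{2581}' ("▁") marks a word boundary; the first one becomes a
    // leading space, which is trimmed away.
    joined.replace('\u{2581}', " ").trim_start().to_string()
}

fn main() {
    let pieces = vec!["▁Hello", "▁wor", "ld", "!"];
    println!("{}", detokenize(&pieces)); // prints "Hello world!"
}
```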

Note that the model can be converted using the following command:

ct2-transformers-converter --model Helsinki-NLP/opus-mt-en-jap --output_dir data/opus-mt-en-jap

Additionally, make sure source.spm and target.spm are copied from the Hugging Face model repository into data/opus-mt-en-jap.
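
Before running the sample, it can help to check that the model directory is complete. The sketch below is a hypothetical std-only helper, not part of ct2rs; it assumes the converter has written model.bin (the weights file CTranslate2 produces) and that source.spm and target.spm were copied in as described above.

```rust
use std::path::Path;

/// Hypothetical helper: list the files this sample expects but cannot find.
/// `model.bin` comes from ct2-transformers-converter; the two `.spm` files
/// are copied from the original model repository.
fn missing_model_files(dir: &Path) -> Vec<&'static str> {
    const REQUIRED: [&str; 3] = ["model.bin", "source.spm", "target.spm"];
    REQUIRED
        .iter()
        .copied()
        .filter(|name| !dir.join(name).is_file())
        .collect()
}

fn main() {
    let dir = Path::new("./data/opus-mt-en-jap");
    let missing = missing_model_files(dir);
    if missing.is_empty() {
        println!("{} looks ready", dir.display());
    } else {
        println!("missing in {}: {:?}", dir.display(), missing);
    }
}
```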

I am considering writing more detailed instructions; in the meantime, feel free to experiment with the code above.

solaoi commented 5 months ago

Thank you, it seems to work with the script you provided.

However, compared with running it in Python, translation quality seems worse for all MarianMT-based models.

Do I need to set parameters like beam_size in TranslationOptions?

jkawamoto commented 5 months ago

I tried using the MarianMT model in Python, but the results were not good either. As you mentioned, the default translation options may not be optimal for this model.

Below is the code I used (taken from CTranslate2's docs):

import ctranslate2
import transformers

translator = ctranslate2.Translator("data/opus-mt-en-jap")
tokenizer = transformers.AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-jap")

source = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello world! This library provides Rust bindings for CTranslate2."))

results = translator.translate_batch([source])
target = results[0].hypotheses[0]

print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target)))

The output I obtained was:

世界 は 上 を は き た. この 築 い た 荷 は, クミン を 描 く.

Could you share the specific parameters or settings you are using for the MarianMT models?

solaoi commented 5 months ago

I compared the following three MarianMT-based models by varying the beam_size and repetition_penalty, using the provided Python code and ctranslate2-rs:

Helsinki-NLP/opus-mt-en-jap

Python

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."

ctranslate2-rs

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."

Hoax0930/marian-finetuned-kde4-en-to-ja_kftt

Python

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."

ctranslate2-rs

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."

staka/fugumt-en-ja

Python

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."

ctranslate2-rs

"Hello world! This library provides Rust bindings for CTranslate2."

"Hitori Gotoh, also known as Bocchi-chan, is one of the main characters in the manga and anime series, Bocchi the Rock!. She is in the first year of Shuka High School and is in charge of the guitar and lyrics of the band, Kessoku Band."
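
For context on the repetition_penalty option varied above: CTranslate2 describes it as a penalty applied to the scores of previously generated tokens, where values greater than 1 discourage repetition. The sketch below assumes a CTRL-style rule (divide positive scores, multiply negative ones) and applies it to a plain score vector; it is an illustration under that assumption, not CTranslate2's actual code, and apply_repetition_penalty is a hypothetical name.

```rust
use std::collections::HashSet;

/// Hypothetical sketch of a CTRL-style repetition penalty (assumed formula,
/// not copied from CTranslate2). `scores` holds one score per vocabulary id.
fn apply_repetition_penalty(scores: &mut [f32], generated: &[usize], penalty: f32) {
    // Penalize each previously generated token once, even if it repeated.
    let seen: HashSet<usize> = generated.iter().copied().collect();
    for id in seen {
        if let Some(s) = scores.get_mut(id) {
            // With penalty > 1, both branches lower the token's score.
            *s = if *s < 0.0 { *s * penalty } else { *s / penalty };
        }
    }
}

fn main() {
    let mut scores = vec![2.0_f32, -1.0, 0.5];
    // Tokens 0 and 1 were already generated (token 0 twice).
    apply_repetition_penalty(&mut scores, &[0, 1, 0], 2.0);
    println!("{:?}", scores); // prints [1.0, -2.0, 0.5]
}
```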

jkawamoto commented 5 months ago

Thank you for sharing the data. I also tested the models you mentioned and discovered that the sentencepiece tokenizer does not append </s> to the end of the token list. When I manually add </s>, the output no longer contains repetitive text. Additionally, I have implemented a modification that omits the target prefix if the model does not require it (#40). With these updates, the outputs are more similar to those produced by Python.
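
A small std-only sketch of the </s> fix described above. with_eos is a hypothetical helper, not a ct2rs API; it appends the end-of-sentence token only when it is not already the last piece, so applying it twice never produces a duplicate </s>.

```rust
/// Hypothetical helper: make sure a MarianMT source sequence ends with `eos`.
fn with_eos(mut pieces: Vec<String>, eos: &str) -> Vec<String> {
    // Only append when the tokenizer did not already emit the token.
    if pieces.last().map(String::as_str) != Some(eos) {
        pieces.push(eos.to_string());
    }
    pieces
}

fn main() {
    let pieces = vec!["▁Hello".to_string(), "▁world".to_string(), "!".to_string()];
    let source = with_eos(pieces, "</s>");
    println!("{:?}", source);
}
```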

Here is my latest code, by the way:

use sentencepiece::SentencePieceProcessor;

use ct2rs::config::Config;
use ct2rs::translator::Translator;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let t = Translator::new("./data/opus-mt-en-jap", Config::default())?;
    let encoder = SentencePieceProcessor::open("./data/opus-mt-en-jap/source.spm")?;
    let decoder = SentencePieceProcessor::open("./data/opus-mt-en-jap/target.spm")?;

    let mut source: Vec<String> = encoder.encode(
        "Hello world! This library provides Rust bindings for CTranslate2.",
    )?.iter().map(|v| v.piece.to_string()).collect();
    // The SentencePiece encoder does not add the end-of-sentence token,
    // so append it explicitly for MarianMT.
    source.push("</s>".to_string());

    // No target prefix is passed here (see #40); the batch is borrowed as a slice.
    let res = t.translate_batch(&vec![source], &Default::default())?;
    for r in res {
        // Print the best hypothesis, decoded back into plain text.
        if let Some(h) = r.hypotheses.get(0) {
            println!("{:?}", decoder.decode_pieces(h)?);
        }
    }
    Ok(())
}

solaoi commented 5 months ago

Thank you! Everything worked perfectly with the main branch. The code you provided had just one typo, shown below; everything else was perfect!

I really appreciate your prompt response. I'm looking forward to future updates!

- let res = t.translate_batch(vec![source], &Default::default())?;
+ let res = t.translate_batch(&vec![source], &Default::default())?;
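
The leading & matters because, judging from this fix, translate_batch takes the batch as a borrowed slice. The stub below is hypothetical and only mimics that assumed parameter shape (check the ct2rs docs for the real signature); it shows that &vec![source] deref-coerces from &Vec<Vec<String>> to &[Vec<String>], whereas passing vec![source] by value would be a type error.

```rust
/// Hypothetical stand-in whose parameter mimics the assumed shape of
/// `translate_batch`: a borrowed slice of token lists.
fn translate_batch_stub<T: AsRef<str>>(source: &[Vec<T>]) -> Vec<usize> {
    // Return the number of tokens per input, just to show the call compiles.
    source.iter().map(|tokens| tokens.len()).collect()
}

fn main() {
    let source = vec!["▁Hello".to_string(), "</s>".to_string()];
    // `&vec![source]` is a `&Vec<Vec<String>>`, which deref-coerces to
    // `&[Vec<String>]`; `vec![source]` by value would not type-check.
    let lens = translate_batch_stub(&vec![source]);
    println!("{:?}", lens); // prints [2]
}
```
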
solaoi commented 4 months ago

@jkawamoto Thank you for adding the example below: https://github.com/jkawamoto/ctranslate2-rs/blob/main/examples/marian-mt.rs

I tried using version 0.7.3, but it seems that Translator::with_tokenizer does not exist.

It works when I use Translator::new instead, as shown below. Is there something wrong with my implementation?

use ct2rs::config::Config;
use ct2rs::sentencepiece::Tokenizer;
use ct2rs::TranslationOptions;
use ct2rs::Translator;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let text = std::env::args()
        .nth(1)
        .unwrap_or("Hello world! This library provides Rust bindings for CTranslate2.".to_string());
    let model_path = "./mymodel";
    let t = Translator::new(
        &model_path,
        Tokenizer::new(&model_path)?,
        &Config::default(),
    )?;
    let sources: Vec<String> = text.lines().map(String::from).collect();

    let res = t.translate_batch(
        &sources,
        &TranslationOptions {
            beam_size: 5,
            ..Default::default()
        },
    )?;
    for (r, _) in res {
        // The first tuple element is the translated text; print one line per input.
        println!("{}", r);
    }
    Ok(())
}

jkawamoto commented 4 months ago

Your code looks good with v0.7.3.

Please refer to the example at v0.7.3 instead of the one on the main branch. The main branch is currently aimed at v0.8.0, which includes some breaking changes.

solaoi commented 4 months ago

Thank you for your quick response! I’ll check the v0.7.3 example as suggested.