OpenNMT / OpenNMT-py

Open Source Neural Machine Translation and (Large) Language Models in PyTorch
https://opennmt.net/
MIT License

[Bug - Translation server] - Missing `tgt` param in `translator.translate` method (fixing it allows some multilingual/seq2seq models to work properly) #2586

Closed: medfreeman closed this issue 3 months ago

medfreeman commented 5 months ago

Some multilingual/seq2seq models such as M2M100 (cf. the Generation section of the linked page) require the bos_token to be set to the target language id, which is carried in the sequence's tgt property.
In the case of the translation server, to be able to specify the requested target language, we need to manipulate the sequence's tgt property directly prior to translation.
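For context, what ultimately has to reach the CTranslate2 backend is that target-language token, supplied as a target prefix. A minimal standalone sketch, assuming a converted M2M100 model directory (the model path and source tokens below are placeholders):

import ctranslate2

translator = ctranslate2.Translator("m2m-multi4-ft-ck945k/")
results = translator.translate_batch(
    [["▁Brian", "▁is", "▁in", "▁the", "▁kitchen", "."]],  # pre-tokenized source (illustrative)
    target_prefix=[["__fr__"]],  # target-language token, i.e. the required bos_token
)
print(results[0].hypotheses[0])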

But in its current state, the server has a disconnect between the sequence's ref/ref_tok properties (which, incidentally, can be manipulated through tokenizers/processors) and the tgt string that is eventually sent to ctranslate2.

cf. https://github.com/OpenNMT/OpenNMT-py/blob/cb1cb22b3de872434076067d316bff446af683ff/onmt/translate/translation_server.py#L588

Basically, the tgt parameter of the self.translator.translate method is never provided.

cf. https://github.com/OpenNMT/OpenNMT-py/blob/cb1cb22b3de872434076067d316bff446af683ff/onmt/translate/translation_server.py#L599

I implemented a one-line patch that passes the parameter through and allows me to do multilingual translation.
It should have no side effects on other types of models (for which the sequence's ref is empty after tokenization), since the parameter is set to an empty string in those cases.

Here’s the PR: #2585
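For reference, the shape of the fix is roughly as follows. This is a sketch, not the exact diff from #2585; texts_to_translate and texts_ref are assumed names for the server's tokenized source and ref strings:

# before (sketch): tgt silently stays at its default
scores, predictions = self.translator.translate(texts_to_translate)

# after (sketch): the ref-derived target prefix is passed through;
# it is an empty string for models whose ref is empty after tokenization
scores, predictions = self.translator.translate(texts_to_translate, tgt=texts_ref)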

Example of multilingual translation with an M2M100 model:

conf.json

{
    "models_root": "./available_models",
    "models": [
        {
            "id": 100,
            "model": "m2m-multi4-ft-ck945k/",
            "ct2_model": "m2m-multi4-ft-ck945k/",
            "load": true,
            "on_timeout": "unload",
            "ct2_translator_args": {
                "inter_threads": 4,
                "intra_threads": 2
            },
            "ct2_translate_batch_args": {},
            "opt": {
                "beam_size": 1,
                "batch_size": 8,
                "tgt_file_prefix": true
            },
            "preprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.preprocess"],
            "postprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.postprocess"]
        }
    ]
}
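With this configuration in place, the server can be started along these lines (assuming the standard onmt_server entry point; flags and paths may vary by version):

onmt_server --ip 0.0.0.0 --port 5000 --config ./available_models/conf.json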

available_models/m2m-multi4-ft-ck945k/tokenizer/m2m100_tokenizer.py

import os
from pathlib import Path
from transformers import M2M100Tokenizer

cache = None

def loadTokenizer(model_root, logger):
    # Load the tokenizer once and reuse it across requests.
    global cache
    if cache is not None:
        return cache

    model_path = os.path.join(model_root, "m2m-multi4-ft-ck945k/tokenizer/")
    logger.info("Loading m2m100 tokenizer from %s", model_path)
    cache = M2M100Tokenizer.from_pretrained(model_path)

    return cache

def preprocess(sequence, server_model):
    """Preprocess a single sequence.

    Args:
        sequence (dict[str, Unknown]): The sequence to preprocess.

    Returns:
        sequence (dict[str, Unknown]): The preprocessed sequence."""
    server_model.logger.info(f"Running preprocessor '{Path(__file__).stem}'")

    ref = sequence.get("ref", None)
    if ref is not None and ref[0] is not None:
        server_model.logger.debug(f"{ref[0]=}")
        tgt_lang = ref[0].get("tgt_lang", None)
        if tgt_lang is not None:
            server_model.logger.debug(f"{tgt_lang=}")

            tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

            # Tokenize the source segment into space-separated subword tokens.
            seg = sequence.get("seg", None)
            tok = tokenizer.convert_ids_to_tokens(
                tokenizer.encode(seg[0])
            )
            tok = " ".join(tok)

            sequence["seg"][0] = tok

            # Replace ref with the target-language token (e.g. "__fr__"),
            # which becomes the target prefix (tgt) for translation.
            lang_prefix = f"__{tgt_lang}__"
            sequence["ref"][0] = lang_prefix
            server_model.logger.info(f"Added lang prefix to ref: '{lang_prefix}'")
            server_model.logger.debug(f"{sequence['ref'][0]=}")

    return sequence

def postprocess(sequence, server_model):
    """Postprocess a single sequence.

    Args:
        sequence (dict[str, Unknown]): The sequence to postprocess.

    Returns:
        str: The detokenized translation."""
    server_model.logger.info(f"Running postprocessor '{Path(__file__).stem}'")

    tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

    seg = sequence.get("seg", None)
    # Drop the leading target-language token emitted by the model ([1:]),
    # then convert the remaining subword tokens back to text.
    detok = tokenizer.decode(
        tokenizer.convert_tokens_to_ids(seg[0].split()[1:]),
        skip_special_tokens=True
    )
    return detok
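To make the data flow concrete, here is a hedged before/after of what preprocess produces for one sequence (the subword tokens are illustrative; the exact output depends on the tokenizer and its configured src_lang):

# before preprocess (as received by the server model):
#   {"seg": ["Brian is in the kitchen."], "ref": [{"src_lang": "en", "tgt_lang": "fr"}]}
# after preprocess (seg tokenized, ref reduced to the target-language token):
#   {"seg": ["__en__ ▁Brian ▁is ▁in ▁the ▁kitchen . </s>"], "ref": ["__fr__"]}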

Sample request to the server:

[
    {
        "src": "Brian is in the kitchen.",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    },
    {
        "src": "By the way, do you like to eat pancakes?",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    }
]
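The same request can be sent programmatically; a minimal client sketch, assuming the server runs locally on port 5000 with the default "/translator" URL root:

import requests

payload = [
    {"src": "Brian is in the kitchen.", "id": 100,
     "ref": {"src_lang": "en", "tgt_lang": "fr"}},
    {"src": "By the way, do you like to eat pancakes?", "id": 100,
     "ref": {"src_lang": "en", "tgt_lang": "fr"}},
]
response = requests.post("http://localhost:5000/translator/translate", json=payload)
print(response.json())  # the exact response shape depends on the server version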
vince62s commented 3 months ago

Please read the README of the project: we are no longer supporting OpenNMT-py and are switching to https://github.com/eole-nlp/eole. I suggest you switch to eole if you intend to get support in the future. The server in eole is not ready yet, but future development will be done there. Cheers.