Helsinki-NLP / Tatoeba-Challenge


Eng to Chinese translation weird behaviour when translating single word #13

Open LeeDustin opened 3 years ago

LeeDustin commented 3 years ago

I am using Marian translation via the transformers library in Python. Most of the time it works great, but for certain inputs it takes a long time and returns weird results.

```python
from transformers import MarianMTModel, MarianTokenizer

print("loading model...")
model_name = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
print("finished loading model")

def HF_translate(word):
    src_text = ['>>cmn_Hans<< ' + word]
    translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
    return tokenizer.decode(translated[0], skip_special_tokens=True)

while True:
    word = input("input word to be translated: ")
    translated_word = HF_translate(word)
    print("translated_word: ", translated_word)
```

Most of the time it works fine and returns within a second, for example:

```
input word to be translated: commission
translated_word:  佣金
```

But if I enter certain words, it takes more than 30 seconds and gives a weird result:

```
input word to be translated: dog
translated_word:  狗狗 狗狗狗 狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗狗
```

The correct translation should be just a single "狗".

Any idea what is wrong or if there are any ways to handle exceptional cases?
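A common mitigation for this kind of degenerate repetition, assuming the cause is unconstrained decoding rather than a model defect, is to pass extra options to `generate()`. This is only a sketch: the parameter values below are illustrative, not tuned, and `max_new_tokens` requires a reasonably recent transformers version.

```python
def HF_translate_bounded(word, model, tokenizer):
    # Same call as above, but with generation constrained so that
    # runaway repetition terminates quickly. Values are illustrative.
    src_text = ['>>cmn_Hans<< ' + word]
    batch = tokenizer(src_text, return_tensors="pt", padding=True)
    translated = model.generate(
        **batch,
        max_new_tokens=64,        # hard cap on output length
        no_repeat_ngram_size=3,   # forbid repeating any 3-gram (e.g. "狗狗狗")
    )
    return tokenizer.decode(translated[0], skip_special_tokens=True)
```

`no_repeat_ngram_size` blocks the exact repetition loop seen in the "dog" example, and the length cap bounds the worst-case runtime even when the model still wants to loop.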

jorgtied commented 3 years ago

Sorry, I don't know. Do you know whether the same happens with the native MarianNMT model?

LeeDustin commented 3 years ago

I have not tried the native Marian model, as I have mainly been using Python and Hugging Face. Is there a major difference between the two?

zhongkaifu commented 3 years ago

You could try this: Seq2SeqSharp Release Package. It also includes an English-Chinese translation model.

iDustPan commented 3 years ago

This bug seems to reproduce only with ">>cmn_Hans<<"; it does not reproduce with ">>cmn_Hant<<".

iDustPan commented 3 years ago

> You could try this: Seq2SeqSharp Release Package. It also includes an English-Chinese translation model.

@zhongkaifu Can you give a demo showing us how to use your model for translation?

zhongkaifu commented 3 years ago

@iDustPan For a demo, you can download the release package zip file from the link above, unzip it, and then run the "test_enu_chs.bat" file. It automatically calls the SentencePiece encoder to tokenize the input sequence, then calls Seq2SeqConsole for translation, and finally calls the SentencePiece decoder to join the translated tokens back into the output sequence.
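That three-step pipeline can be sketched in Python. The executable names come from the SentencePiece CLI (`spm_encode`/`spm_decode`) and the Seq2SeqConsole tool named above, but the model file names and flags here are assumptions for illustration; the real wiring lives in test_enu_chs.bat.

```python
import subprocess

def pipeline_commands(src="input.en", out="output.zh"):
    # Build the three commands; file/flag names are illustrative assumptions.
    return [
        # 1. SentencePiece encoder: raw text -> subword tokens
        ["spm_encode", "--model=enu.spm.model", f"--output={src}.tok", src],
        # 2. Seq2SeqConsole in Test mode: translate the tokenized input
        ["Seq2SeqConsole", "-Task", "Test",
         "-InputTestFile", f"{src}.tok", "-OutputFile", f"{out}.tok"],
        # 3. SentencePiece decoder: join translated tokens back into text
        ["spm_decode", "--model=chs.spm.model", f"--output={out}", f"{out}.tok"],
    ]

def run_pipeline(src="input.en", out="output.zh"):
    # Run the steps in order, failing fast if any stage errors out.
    for cmd in pipeline_commands(src, out):
        subprocess.run(cmd, check=True)
```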

For the code, you can look at the Seq2SeqConsole project in the Seq2SeqSharp repo. I've copied and pasted some code here that includes the functions for training, validation, and testing. Let me know if you have any additional questions.

```c#
            //Parse command line
            ArgParser argParser = new ArgParser(args, opts);
            if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
            {
                Console.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject<Seq2SeqOptions>(File.ReadAllText(opts.ConfigFilePath));
            }

            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log";
            ShowOptions(args, opts);

            Seq2Seq ss = null;
            ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.Task);
            ShuffleEnums shuffleType = (ShuffleEnums)Enum.Parse(typeof(ShuffleEnums), opts.ShuffleType);
            TooLongSequence tooLongSequence = (TooLongSequence)Enum.Parse(typeof(TooLongSequence), opts.TooLongSequence);

            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                Seq2SeqCorpus trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, batchSize: opts.BatchSize, shuffleBlockSize: opts.ShuffleBlockSize,
                    maxSrcSentLength: opts.MaxTrainSrcSentLength, maxTgtSentLength: opts.MaxTrainTgtSentLength, shuffleEnums: shuffleType, tooLongSequence: tooLongSequence);

                // Load valid corpus
                List<Seq2SeqCorpus> validCorpusList = new List<Seq2SeqCorpus>();
                if (String.IsNullOrEmpty(opts.ValidCorpusPaths) == false)
                {
                    string[] validCorpusPathList = opts.ValidCorpusPaths.Split(';');
                    foreach (var validCorpusPath in validCorpusPathList)
                    {
                        validCorpusList.Add(new Seq2SeqCorpus(validCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxTestSrcSentLength, opts.MaxTestTgtSentLength, shuffleEnums: shuffleType, tooLongSequence: tooLongSequence));
                    }

                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                IOptimizer optimizer = Misc.CreateOptimizer(opts);

                // Create metrics
                List<IMetric> metrics = CreateMetrics();

                if (!String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath))
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    ss = new Seq2Seq(opts);
                }
                else
                {
                    // Load or build vocabulary
                    Vocab srcVocab = null;
                    Vocab tgtVocab = null;
                    if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
                    {
                        Logger.WriteLine($"Loading source vocabulary from '{opts.SrcVocab}' and target vocabulary from '{opts.TgtVocab}'. Shared vocabulary is '{opts.SharedEmbeddings}'");
                        if (opts.SharedEmbeddings == true && (opts.SrcVocab != opts.TgtVocab))
                        {
                            throw new ArgumentException("The source and target vocabularies must be identical if their embeddings are shared.");
                        }

                        // Vocabulary files are specified, so we load them
                        srcVocab = new Vocab(opts.SrcVocab);
                        tgtVocab = new Vocab(opts.TgtVocab);
                    }
                    else
                    {
                        Logger.WriteLine($"Building vocabulary from training corpus. Shared vocabulary is '{opts.SharedEmbeddings}'");
                        // We don't specify vocabulary, so we build it from train corpus

                        (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize, opts.TgtVocabSize, opts.SharedEmbeddings);
                    }

                    //New training
                    ss = new Seq2Seq(opts, srcVocab, tgtVocab);
                }

                // Add event handler for monitoring
                ss.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;
                ss.EvaluationWatcher += Ss_EvaluationWatcher;

                // Kick off training
                ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics);
            }
            else if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPaths}'");

                // Create metrics
                List<IMetric> metrics = CreateMetrics();

                // Load valid corpus
                Seq2SeqCorpus validCorpus = new Seq2SeqCorpus(opts.ValidCorpusPaths, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxTestSrcSentLength, opts.MaxTestTgtSentLength, shuffleEnums: shuffleType, tooLongSequence: tooLongSequence);

                ss = new Seq2Seq(opts);
                ss.EvaluationWatcher += Ss_EvaluationWatcher;
                ss.Valid(validCorpus: validCorpus, metrics: metrics);
            }
            else if (mode == ModeEnums.Test)
            {
                if (File.Exists(opts.OutputFile))
                {
                    Logger.WriteLine(Logger.Level.err, ConsoleColor.Yellow, $"Output file '{opts.OutputFile}' exist. Delete it.");
                    File.Delete(opts.OutputFile);
                }

                //Test trained model
                ss = new Seq2Seq(opts);
                Stopwatch stopwatch = Stopwatch.StartNew();
                ss.Test<Seq2SeqCorpusBatch>(opts.InputTestFile, opts.OutputFile, opts.BatchSize, opts.MaxTestSrcSentLength);

                stopwatch.Stop();

                Logger.WriteLine($"Test mode execution time elapsed: '{stopwatch.Elapsed}'");
            }

```