facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

How to use fairseq.models.transformer.TransformerModel.from_pretrained on Multilingual translation model #5470

Closed JoeyandLucy closed 7 months ago

JoeyandLucy commented 7 months ago

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

I need to load a multilingual translation model with fairseq.models.transformer.TransformerModel.from_pretrained, but I don't know which arguments to pass, and the official docs describe the parameters only briefly.

Code

import torch
from fairseq.models import transformer
from fairseq.data import encoders, dictionaries

# Define the device, use GPU if available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading my own Transformer model
model = transformer.TransformerModel.from_pretrained(
    'fairseq/examples/translation/ Transformer/ wmt16.en-de.joined-dict',
    checkpoint_file='base_tag2_cts_lee checkpoint_avg_5.pt',
    bpe='sentencepiece',
    # data_name_or_path='wmt16',
    dict_file=[('en', 'data-bin/dict.src.txt'), ('de', 'data-bin/dict.tgt.txt')],
    encoder_langtok=None,
    decoder_langtok=None,
    join_dict=True
)

# Transfer the model to the GPU (if available)
model.to(device)

# Prepare to enter a sentence
sentence = 'Hello, how are you?'
encoder = encoders.get_permissive_encoder('en', 'data-bin/dict.src.txt')
tokens = encoder.encode(sentence, add_bos=True, add_eos=True)
src_idx = torch.LongTensor([tokens]).to(device)

# translate
with torch.no_grad():
    hypo = model.inference(src_idx)
    hypo_tokens = [hypo[s].item() for s in hypo]

# Converting translated tokens back to text
decoder = dictionaries.get_permissive_decoder('de', 'data-bin/dict.tgt.txt')
translation = decoder.decode(hypo_tokens)

print("Input sentence:", sentence)
print("Translation:", translation)

When I run this script, the following error occurs:

2024-04-02 12:29:06 | INFO | fairseq.file_utils | loading archive file ../
Traceback (most recent call last):
  File "mytranslate2.py", line 14, in <module>
    dict_file=[('<en>', '../data-bin/dict.src.txt'), ('<de>', '../data-bin/dict.tgt.txt')],
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/models/fairseq_model.py", line 272, in from_pretrained
    **kwargs,
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/hub_utils.py", line 75, in from_pretrained
    arg_overrides=kwargs,
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/checkpoint_utils.py", line 430, in load_model_ensemble_and_task
    task = tasks.setup_task(cfg.task)
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/tasks/__init__.py", line 46, in setup_task
    return task.setup_task(cfg, **kwargs)
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/tasks/translation.py", line 304, in setup_task
    "Could not infer language pair, please provide it explicitly"
Exception: Could not infer language pair, please provide it explicitly
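For context, the exception is raised in the translation task's setup_task, after fairseq fails to work out the source and target languages from the data directory. The sketch below is a standalone approximation of that kind of check, not fairseq's own code: it assumes the pair is read from binarized file names shaped like train.en-de.en.idx, and the function name infer_language_pair_sketch is made up for illustration. The real logic may differ between fairseq versions and forks.

```python
import os
import tempfile
from typing import Optional, Tuple


def infer_language_pair_sketch(data_dir: str) -> Tuple[Optional[str], Optional[str]]:
    """Approximate how a language pair can be read off binarized file names.

    Looks for names shaped like <split>.<src>-<tgt>.<lang>.<ext>
    (for example train.en-de.en.idx) and returns (src, tgt) from the
    first match, or (None, None) when no file has that shape.
    """
    for filename in sorted(os.listdir(data_dir)):
        parts = filename.split(".")
        if len(parts) >= 3:
            pair = parts[1].split("-")
            if len(pair) == 2:
                return pair[0], pair[1]
    return None, None


def demo():
    """Compare a directory holding only dictionaries with a binarized one."""
    with tempfile.TemporaryDirectory() as only_dicts, \
            tempfile.TemporaryDirectory() as binarized:
        # Only dictionary files: nothing encodes a "<src>-<tgt>" pair.
        for name in ("dict.src.txt", "dict.tgt.txt"):
            open(os.path.join(only_dicts, name), "w").close()
        # Binarized data: the pair is visible in the file names.
        for name in ("dict.en.txt", "dict.de.txt",
                     "train.en-de.en.idx", "train.en-de.en.bin"):
            open(os.path.join(binarized, name), "w").close()
        return (infer_language_pair_sketch(only_dicts),
                infer_language_pair_sketch(binarized))
```

Under these assumptions, a directory that holds only dict.src.txt and dict.tgt.txt yields no pair, which would be consistent with the "Could not infer language pair" message when the languages are not also given explicitly.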

The official documentation describes this function's parameters only as follows:

Load a FairseqModel from a pre-trained model file. Downloads and caches the pre-trained model file if needed. The base implementation returns a GeneratorHubInterface, which can be used to generate translations or sample from language models. The underlying FairseqModel can be accessed via the generator.models attribute. Other models may override this to implement custom hub interfaces.

Parameters:

  model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
  checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: ‘model.pt’)
  data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with ‘.’ or ‘./’ to reuse the model archive path.
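Beyond the three documented parameters, extra keyword arguments appear to be forwarded as overrides of the saved training arguments (the traceback shows arg_overrides=kwargs in hub_utils.py). A minimal sketch of a call built on that assumption follows. The helper names are hypothetical, and the option names source_lang, target_lang, bpe and sentencepiece_model are assumed to match the translation task and sentencepiece encoder options of the installed fairseq version, so they should be checked against it.

```python
from typing import Any, Dict, Optional


def build_from_pretrained_kwargs(
    checkpoint_file: str,
    data_dir: str,
    source_lang: str,
    target_lang: str,
    bpe: Optional[str] = None,
    sentencepiece_model: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect keyword arguments for TransformerModel.from_pretrained.

    checkpoint_file and data_name_or_path are documented parameters; the
    remaining keys are assumed to be passed through as argument overrides.
    """
    kwargs: Dict[str, Any] = {
        "checkpoint_file": checkpoint_file,
        "data_name_or_path": data_dir,
        # Assumption: naming the pair explicitly avoids the inference step.
        "source_lang": source_lang,
        "target_lang": target_lang,
    }
    if bpe is not None:
        kwargs["bpe"] = bpe
    if sentencepiece_model is not None:
        # Assumption: option name used by fairseq's sentencepiece encoder.
        kwargs["sentencepiece_model"] = sentencepiece_model
    return kwargs


def load_and_translate(model_dir: str, sentence: str, **kwargs: Any) -> str:
    """Load the hub interface and translate one raw sentence."""
    # Imported lazily so the helper above works without fairseq installed.
    from fairseq.models.transformer import TransformerModel

    hub = TransformerModel.from_pretrained(model_dir, **kwargs)
    hub.eval()  # disable dropout for inference
    # The returned hub interface tokenizes, applies BPE, generates and
    # detokenizes, so no manual encoder or decoder objects are needed.
    return hub.translate(sentence)
```

Usage would look like load_and_translate(model_dir, 'Hello, how are you?', **build_from_pretrained_kwargs(checkpoint_file, data_dir, 'en', 'de')), with model_dir, checkpoint_file and data_dir standing in for your own paths. As far as I know, the translation task then looks for dict.<source_lang>.txt and dict.<target_lang>.txt inside data_dir, so the language codes have to match the dictionary file names.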

What have you tried?

I've tried changing the format of the language labels in dict_file a few times (for example 'en' versus '<en>'), but the same error occurs each time.

What's your environment?