❓ Questions and Help

What is your question?

I need to load a multilingual translation model with fairseq.models.transformer.TransformerModel.from_pretrained, but I don't know how to fill in its parameters, and the official docs give only a very abbreviated description of them.
Code
import torch
from fairseq.models import transformer
from fairseq.data import encoders, dictionaries

# Use the GPU if available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load my own Transformer model
model = transformer.TransformerModel.from_pretrained(
    'fairseq/examples/translation/Transformer/wmt16.en-de.joined-dict',
    checkpoint_file='base_tag2_cts_lee checkpoint_avg_5.pt',
    bpe='sentencepiece',
    # data_name_or_path='wmt16',
    dict_file=[('en', 'data-bin/dict.src.txt'), ('de', 'data-bin/dict.tgt.txt')],
    encoder_langtok=None,
    decoder_langtok=None,
    join_dict=True
)

# Move the model to the GPU (if available)
model.to(device)

# Prepare an input sentence
sentence = 'Hello, how are you?'
encoder = encoders.get_permissive_encoder('en', 'data-bin/dict.src.txt')
tokens = encoder.encode(sentence, add_bos=True, add_eos=True)
src_idx = torch.LongTensor([tokens]).to(device)

# Translate
with torch.no_grad():
    hypo = model.inference(src_idx)
    hypo_tokens = [hypo[s].item() for s in hypo]

# Convert the translated tokens back to text
decoder = dictionaries.get_permissive_decoder('de', 'data-bin/dict.tgt.txt')
translation = decoder.decode(hypo_tokens)

print("Input sentence:", sentence)
print("Translation:", translation)
When I run this script, the following error occurs:
2024-04-02 12:29:06 | INFO | fairseq.file_utils | loading archive file ../
Traceback (most recent call last):
  File "mytranslate2.py", line 14, in <module>
    dict_file=[('<en>', '../data-bin/dict.src.txt'), ('<de>', '../data-bin/dict.tgt.txt')],
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/models/fairseq_model.py", line 272, in from_pretrained
    **kwargs,
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/hub_utils.py", line 75, in from_pretrained
    arg_overrides=kwargs,
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/checkpoint_utils.py", line 430, in load_model_ensemble_and_task
    task = tasks.setup_task(cfg.task)
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/tasks/__init__.py", line 46, in setup_task
    return task.setup_task(cfg, **kwargs)
  File "/data1/liuchang/temp_data/fairseq-lee/fairseq/tasks/translation.py", line 304, in setup_task
    "Could not infer language pair, please provide it explicitly"
Exception: Could not infer language pair, please provide it explicitly
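While debugging, I also dumped the training config stored inside the checkpoint to see which task and language settings it was saved with. As far as I can tell, fairseq checkpoints keep this under 'cfg' in recent versions ('args' in older ones); the path below is a placeholder for my own checkpoint:

import torch

# Load only the checkpoint dict (not the model) to inspect its saved config.
# Recent fairseq saves the config under 'cfg'; older checkpoints use 'args'.
state = torch.load('path/to/checkpoint.pt', map_location='cpu')
print(state.get('cfg', state.get('args')))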
However, the official documentation's entire description of this function and its parameters is as follows:
Load a FairseqModel from a pre-trained model file. Downloads and caches the pre-trained model file if needed.
The base implementation returns a GeneratorHubInterface, which can be used to generate translations or sample from language models. The underlying FairseqModel can be accessed via the generator.models attribute.
Other models may override this to implement custom hub interfaces.
Parameters:
- model_name_or_path (str) – either the name of a pre-trained model to load or a path/URL to a pre-trained model state dict
- checkpoint_file (str, optional) – colon-separated list of checkpoint files in the model archive to ensemble (default: 'model.pt')
- data_name_or_path (str, optional) – point args.data to the archive at the given path/URL. Can start with '.' or './' to reuse the model archive path.
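For comparison, the closest thing to a complete example I have found is in examples/translation/README.md, which loads a bilingual model roughly like this (the paths and BPE settings are the README's, not mine):

from fairseq.models.transformer import TransformerModel

# Load a trained en-de model from a checkpoint directory; data_name_or_path
# points at the binarized data directory that holds the dictionaries.
en2de = TransformerModel.from_pretrained(
    '/path/to/checkpoints',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin/wmt17_en_de',
    bpe='subword_nmt',
    bpe_codes='data-bin/wmt17_en_de/code',
)
print(en2de.translate('Hello world!'))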
What have you tried?
I've tried several forms for the language labels in dict_file (for example 'en' and '<en>', as in the traceback above), but the error is the same every time. A sketch of what I plan to try next is below.
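Reading fairseq/tasks/translation.py, setup_task seems to fall back to data_utils.infer_language_pair on the data directory when source_lang/target_lang are unset, which is where my exception comes from. Since hub_utils.from_pretrained forwards extra keyword arguments as arg_overrides (visible in my traceback above), my next idea is to pass the pair explicitly; an untested sketch:

from fairseq.models import transformer

model = transformer.TransformerModel.from_pretrained(
    'fairseq/examples/translation/Transformer/wmt16.en-de.joined-dict',
    checkpoint_file='base_tag2_cts_lee checkpoint_avg_5.pt',
    data_name_or_path='data-bin',  # directory with the binarized data and dictionaries
    bpe='sentencepiece',
    source_lang='en',  # forwarded as arg_overrides onto the task config
    target_lang='de',
)

I also suspect the translation task expects the dictionaries in that directory to be named dict.en.txt and dict.de.txt rather than dict.src.txt and dict.tgt.txt, but I haven't confirmed this.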
What's your environment?
fairseq Version: 1.0.0a0
PyTorch Version: 1.13.1+cu117
OS: Linux
How you installed fairseq: cloned from GitHub
Build command you used (if compiling from source):