facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

XGLM Example does not run #4210

Open afcruzs-ms opened 2 years ago

afcruzs-ms commented 2 years ago

🐛 Bug

The XGLM example (https://github.com/pytorch/fairseq/tree/main/examples/xglm) has two issues: 1) fairseq.models.build_model fails because the model_type (transformer_lm_gpt2_big_wide) is not among the registered models, and 2) the issue described in #4209.

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run the example at https://github.com/pytorch/fairseq/tree/main/examples/xglm
  2. You'll see an error complaining that "_name" does not exist, which happens because the model object in the build_model function is None. This can be worked around by hardcoding model_type = 'transformer_lm'.
  3. With the workaround in 2), you'll hit the issue in #4209. This can be mitigated by removing the pad_length argument passed to PadDataset in build_dataset_for_inference.

With these changes I'm able to run the example; however, it's unclear whether these hacks preserve the correctness of the model.
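
The registry lookup behind the first problem can be checked in isolation, without going through from_pretrained. The snippet below is only a minimal sketch: `check_arch` and `toy_registry` are hypothetical names made up for illustration, and the toy dictionary does not reflect the real registry contents. With fairseq installed, the mapping to pass in is presumably `fairseq.models.ARCH_MODEL_REGISTRY`, but that is an assumption about the installed version.

```python
import difflib


def check_arch(arch, registry):
    """Return (is_registered, close_matches) for an architecture name.

    `registry` is any mapping keyed by architecture name. Assumption: with
    fairseq installed this could be fairseq.models.ARCH_MODEL_REGISTRY.
    """
    names = list(registry)
    # Suggest up to three registered names that resemble the requested one.
    return arch in registry, difflib.get_close_matches(arch, names, n=3)


# Toy registry for illustration only; NOT the real fairseq contents.
toy_registry = {
    "transformer_lm": object,
    "transformer_lm_gpt2_big": object,
}

found, suggestions = check_arch("transformer_lm_gpt2_big_wide", toy_registry)
print(found, suggestions)
```

If the name is missing while close variants are present, that would point to a checkout that predates the registration of the architecture rather than to a corrupted checkpoint.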

Stack trace for error in 2):

Traceback (most recent call last):
  File "go.py", line 55, in <module>
    lm = TransformerLanguageModel.from_pretrained(model_dir, bpe='sentencepiece')
  File "/home/ancruzsa/fairseq/fairseq/models/fairseq_model.py", line 267, in from_pretrained
    x = hub_utils.from_pretrained(
  File "/home/ancruzsa/fairseq/fairseq/hub_utils.py", line 73, in from_pretrained
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
  File "/home/ancruzsa/fairseq/fairseq/checkpoint_utils.py", line 469, in load_model_ensemble_and_task
    model = task.build_model(cfg.model, from_checkpoint=True)
  File "/home/ancruzsa/fairseq/fairseq/tasks/multilingual_language_modeling.py", line 260, in build_model
    model = super().build_model(args, from_checkpoint)
  File "/home/ancruzsa/fairseq/fairseq/tasks/fairseq_task.py", line 671, in build_model
    model = models.build_model(args, self, from_checkpoint)
  File "/home/ancruzsa/fairseq/fairseq/models/__init__.py", line 102, in build_model
    f"Could not infer model type from {cfg}. "
KeyError: "'_name'"

Stack trace for error in 3):

Traceback (most recent call last):
  File "go.py", line 75, in <module>
    predict = COPA_eval(example["premise"], example["choice1"], example["choice2"])
  File "go.py", line 69, in COPA_eval
    lprob1 = get_logprobs(prompt + "\n" + alternative1).sum()
  File "go.py", line 63, in get_logprobs
    return lm.score(prompt, replace_newlines_with_eos=True)['positional_scores']
  File "/home/ancruzsa/fairseq/fairseq/hub_utils.py", line 139, in score
    return self.score(
  File "/home/ancruzsa/fairseq/fairseq/hub_utils.py", line 153, in score
    for hypos in self.generate(
  File "/home/ancruzsa/fairseq/fairseq/hub_utils.py", line 187, in generate
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
  File "/home/ancruzsa/fairseq/fairseq/hub_utils.py", line 275, in _build_batches
    dataset=self.task.build_dataset_for_inference(tokens, lengths),
  File "/home/ancruzsa/fairseq/fairseq/tasks/multilingual_language_modeling.py", line 533, in build_dataset_for_inference
    "src_tokens": PadDataset(
TypeError: __init__() got an unexpected keyword argument 'pad_length'
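
The TypeError above means the installed PadDataset constructor does not declare a pad_length parameter. One way to confirm this, and to drop unsupported keyword arguments in one place instead of editing each call site, is to inspect the constructor signature. This is a rough sketch under stated assumptions: `supported_kwargs` and `OldPadDataset` are hypothetical stand-ins written for this example, and the stand-in's parameter list is a guess, not the real fairseq class.

```python
import inspect


def supported_kwargs(cls, kwargs):
    """Keep only the keyword arguments that cls.__init__ declares."""
    params = inspect.signature(cls.__init__).parameters
    # A constructor that takes **kwargs accepts everything, so filter nothing.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {name: value for name, value in kwargs.items() if name in params}


class OldPadDataset:
    """Hypothetical stand-in for a PadDataset version without pad_length."""

    def __init__(self, dataset, pad_idx, left_pad):
        self.dataset = dataset
        self.pad_idx = pad_idx
        self.left_pad = left_pad


requested = {"pad_idx": 1, "left_pad": False, "pad_length": 8}
filtered = supported_kwargs(OldPadDataset, requested)
padded = OldPadDataset([[4, 5, 6]], **filtered)  # no TypeError
print(sorted(filtered))
```

Dropping pad_length this way only avoids the crash. Whether the resulting padding matches what the model expects still has to be verified separately.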

Code sample

from fairseq.models.transformer_lm import TransformerLanguageModel

data_samples = {
    'en': [
        {
            "premise": "I wanted to conserve energy.", 
            "choice1": "I swept the floor in the unoccupied room.", 
            "choice2": "I shut off the light in the unoccupied room.",
            "question": "effect",
            "label": "1"
        },
        {
            "premise": "The flame on the candle went out.",
            "choice1": "I blew on the wick.", 
            "choice2": "I put a match to the wick.",
            "question": "cause",
            "label": "0"
        }
    ],
    'zh': [
        {
            "premise": "我想节约能源。", 
            "choice1": "我在空着的房间里扫了地板。", 
            "choice2": "我把空房间里的灯关了。",
            "question": "effect",
            "label": "1"
        },
        {
            "premise": "蜡烛上的火焰熄灭了。",
            "choice1": "我吹灭了灯芯。", 
            "choice2": "我把一根火柴放在灯芯上。",
            "question": "cause",
            "label": "0"
        }
    ],
    'hi': [  # note: the sample sentences under this key are Haitian Creole
        {
            "premise": "M te vle konsève enèji.", 
            "choice1": "Mwen te fin baleye chanm lib la.", 
            "choice2": "Mwen te femen limyè nan chanm lib la.",
            "question": "effect",
            "label": "1"
        },
        {
            "premise": "Flam bouji a te etenn.",
            "choice1": "Mwen te soufle bouji a.", 
            "choice2": "Mwen te limen mèch bouji a.",
            "question": "cause",
            "label": "0"
        }
    ]
}

model_dir = "/path/to/xglm.564M.tar.gz"  # placeholder: replace with the actual path to xglm.564M.tar.gz
lm = TransformerLanguageModel.from_pretrained(model_dir, bpe='sentencepiece')
lm = lm.eval()
lm = lm.half()
lm = lm.cuda()

def get_logprobs(prompt):
    import re
    prompt = re.sub('\n+' , '\n', prompt)  # collapse repeated newlines, which indicate separate documents
    return lm.score(prompt, replace_newlines_with_eos=True)['positional_scores']

# Zero-shot evaluation for the Choice of Plausible Alternatives (COPA) task.
# A return value of 0 indicates that the first alternative is more plausible,
# while 1 indicates that the second alternative is more plausible.
def COPA_eval(prompt, alternative1, alternative2):
    lprob1 = get_logprobs(prompt + "\n" + alternative1).sum()
    lprob2 = get_logprobs(prompt + "\n" + alternative2).sum()
    return 0 if lprob1 > lprob2 else 1

for lang in ['en', 'zh', 'hi']:
    for idx, example in enumerate(data_samples[lang]):
        predict = COPA_eval(example["premise"], example["choice1"], example["choice2"])
        print(f'{lang}-{idx}', predict, example['label'])
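
Since the loop above prints an integer prediction next to a string label, a small tally makes it easier to judge whether the workarounds still give sensible results. The helper below is a hedged sketch and not part of the XGLM example: `accuracy` is a hypothetical name, and it assumes every label parses as an integer, as the "0" and "1" labels in the samples do.

```python
def accuracy(predictions, labels):
    """Fraction of predictions equal to their labels.

    Labels may be strings such as "0" or "1"; both sides are compared as ints.
    """
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not predictions:
        return 0.0
    correct = sum(int(p) == int(l) for p, l in zip(predictions, labels))
    return correct / len(predictions)


# Example with made-up predictions, not real model output.
print(accuracy([1, 0], ["1", "0"]))
```

With only two samples per language the number says little, so this is a sanity check rather than an evaluation.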

Expected behavior

The model should load and run correctly as trained.

Environment

Additional context

jungokasai commented 2 years ago

It looks like this has been resolved in the latest main branch!

todpole3 commented 2 years ago

@afcruzs-ms Would you please close this issue if the problem has been solved? Thanks!