huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

AttributeError: 'DistributedDataParallel' object has no attribute 'generate' #109

Closed · Ravoxsg closed this issue 3 years ago

Ravoxsg commented 3 years ago

Hi,

I am trying to generate text with HuggingFace's `model.generate()` function. However, it does not work with either one GPU or multiple GPUs, and I get the following error:

```
Traceback (most recent call last):
  File "ft_main.py", line 234, in <module>
    main(args)
  File "ft_main.py", line 169, in main
    training_loop_accelerate(train_loader, val_loader, test_loader, tokenizer, model, optimizer, scheduler, accelerator, device, args)
  File "/home/mathieu/multilingual_summarization/multilingual_summarization/ft_engine_accelerate.py", line 20, in training_loop_accelerate
    scores = validate("val", val_loader, tokenizer, model, device, scores, args)
  File "/home/mathieu/multilingual_summarization/multilingual_summarization/ft_engine_accelerate.py", line 148, in validate
    summary_ids = model.generate(
  File "/home/mathieu/anaconda3/envs/mlsumm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'generate'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 44726) of binary: /home/mathieu/anaconda3/envs/mlsumm/bin/pytho$
ERROR:torch.distributed.elastic.agent.server.local_elastic_agent:[default] Worker group failed
```

I don't get any error when not in generation mode, though. My setup is the following:

I am using `MT5ForConditionalGeneration` as the model; a minimal sketch of the distributed setup is below.
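For context, here is a minimal sketch of this kind of setup (the checkpoint name is a placeholder, not the exact script):

```python
from accelerate import Accelerator
from transformers import MT5ForConditionalGeneration

accelerator = Accelerator()
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")  # placeholder checkpoint
model = accelerator.prepare(model)
# On a multi-process launch, `model` is now a
# torch.nn.parallel.DistributedDataParallel wrapping the MT5 model.
```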

And here is my validate function:

```python
import gc
from time import time

from tqdm import tqdm
from rouge_score import rouge_scorer


def validate(mode, loader, tokenizer, model, device, scores, args):
    model.eval()

    losses = []
    times = []
    all_r1 = []
    all_rl = []
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=False)

    for idx, batch in tqdm(enumerate(loader)):
        t1 = time()
        text = batch["text"]
        text_lang = batch["text_lang"]
        text_inputs = batch["text_inputs"]
        summary = batch["summary"]
        summary_lang = batch["summary_lang"]
        summary_inputs = batch["summary_inputs"]

        for k in text_inputs.keys():
            text_inputs[k] = text_inputs[k].squeeze(1).to(device)
            summary_inputs[k] = summary_inputs[k].squeeze(1).to(device)

        # Teacher-forced forward pass to compute the validation loss.
        outputs = model(**text_inputs, labels=summary_inputs["input_ids"])

        loss = outputs["loss"]
        losses.append(loss.item())
        times.append(time() - t1)

        del outputs
        del text_inputs
        del summary_inputs
        gc.collect()

        for i in range(len(summary)):
            src_lang = text_lang[i]
            tgt_lang = summary_lang[i]

            inputs = tokenizer(text[i], return_tensors="pt")
            inputs["input_ids"] = inputs["input_ids"][:, :args.max_length].to(device)
            inputs["attention_mask"] = inputs["attention_mask"][:, :args.max_length].to(device)

            if args.model_type == "mt5":
                # This is the call that raises the AttributeError under
                # DistributedDataParallel.
                summary_ids = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    use_cache=True,
                    num_beams=args.num_beams,
                    max_length=args.max_summary_length,
                    early_stopping=True,
                    repetition_penalty=args.repetition_penalty,
                    length_penalty=args.length_penalty
                )
```
sgugger commented 3 years ago

You have not explained how you created your model or how you launch this script, but the stack trace indicates there is a `DistributedDataParallel` involved and some distributed launch. As the error indicates, the model wrapped in a `DistributedDataParallel` is not a Transformers model anymore and has no `generate` method. You need to access the model you wrapped with `model.module` to get back your Transformers model, so use `model.module.generate`.
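Concretely, a minimal sketch of the fix inside the validation loop above (variable names follow that snippet; the `hasattr` guard and the `accelerator.unwrap_model` alternative are suggestions, not code from this thread):

```python
# Reach through the DistributedDataParallel wrapper to the Transformers model.
# On a single-process run there is no wrapper, hence the hasattr guard.
unwrapped = model.module if hasattr(model, "module") else model
# Equivalent helper from Accelerate: unwrapped = accelerator.unwrap_model(model)

summary_ids = unwrapped.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=args.num_beams,
    max_length=args.max_summary_length,
)
```

The guard keeps the same code working whether or not the model was wrapped for a distributed launch.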

Ravoxsg commented 3 years ago

Calling `model.module.generate()` indeed fixed it, thank you @sgugger!

sgugger commented 3 years ago

Closing this issue then, glad your problem is fixed!