pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

Adding 'labels' input to model with 'include_loss_args' fails hf examples #1119

Open alexlan137 opened 4 months ago

alexlan137 commented 4 months ago

Hi,

I'm trying to use PiPPy with a custom model that takes both 'input_ids' and 'labels' as inputs. To test this, I modified the basic pippy_gpt2.py example by first changing the model_class and model_name to GPT2LMHeadModel, and then setting include_loss_args to True in the call that generates example_inputs:

example_inputs = generate_inputs_for_model(model_class, gpt2, model_name, args.batch_size, args.device, include_loss_args=True)

However, this fails with the following traceback:

[rank0]: TypeError: forward() got an unexpected keyword argument 'labels'
RuntimeError: 
[rank0]:             [Stage 0] failed to run forward:
[rank0]:             args: ()
[rank0]:             kwargs: {'input_ids': 'Tensor(torch.Size([1, 1024]), grad=False)', 'labels': 'Tensor(torch.Size([1, 1024]), grad=False)'}

This occurs because PiPPy splits the graph module (split_gm) such that the 'labels' input is routed to the last (4th) submodule, so the first submodule does not expect a 'labels' argument.
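For anyone debugging the same thing: each split submodule is a torch.fx GraphModule, so you can list its placeholder nodes to see which stage actually expects 'labels'. A minimal standalone sketch — the Toy model below is a made-up stand-in, not PiPPy code, but the same loop works on split_gm's children:

```python
import torch
from torch import fx

class Toy(torch.nn.Module):
    # Stand-in for one pipeline submodule that takes both inputs.
    def forward(self, input_ids, labels):
        logits = input_ids.float() * 2
        return (logits - labels.float()).abs().mean()

gm = fx.symbolic_trace(Toy())

# Placeholder nodes are the submodule's expected inputs; running this
# over each submodule of split_gm shows where 'labels' ended up.
placeholders = [n.name for n in gm.graph.nodes if n.op == "placeholder"]
print(placeholders)
```
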

I also tried modifying pippy_gpt2.py to feed the labels to the last submodule via schedule.step as follows (although this is not ideal as a long-term solution):

input_values = torch.randint(0, 50257, (args.batch_size, 1024), device=args.device)
example_inputs_0 = {"input_ids": input_values}  # first stage: model inputs
example_inputs_3 = {"labels": input_values}     # last stage: loss targets

# Run
if args.rank == 0:
    schedule.step(**example_inputs_0)
elif args.rank == 3:
    schedule.step(**example_inputs_3)
else:
    out = schedule.step()

This throws the following error, probably because internal submodules expect RecvInfo and tensors from previous stages rather than fresh values from input placeholders:

[rank3]: AssertionError: Expected RecvInfo but got <class 'torch.distributed.pipelining._PipelineStage.RootArgPlaceholder'>

I could try to debug further, but is there a better solution, or does anyone have ideas for how to implement this? Thanks.
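One alternative worth trying, if your PyTorch version's torch.distributed.pipelining schedules accept a loss_fn (newer releases do): keep 'labels' out of the traced model inputs entirely and pass them as target on the last rank. A hedged sketch — the ScheduleGPipe/step(target=...) names follow the upstream torch.distributed.pipelining API and may differ in older PiPPy; the shift-and-cross-entropy mirrors what GPT2LMHeadModel computes from labels internally:

```python
import torch
import torch.nn.functional as F

def loss_fn(output, target):
    # output: last stage's logits [batch, seq, vocab]; target: label ids.
    # Shift so tokens < n predict token n, as in GPT2LMHeadModel's loss.
    shift_logits = output[..., :-1, :].contiguous()
    shift_labels = target[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )

# Wiring sketch (distributed setup elided; names follow pippy_gpt2.py):
# schedule = ScheduleGPipe(stage, n_microbatches=args.chunks, loss_fn=loss_fn)
# if args.rank == world_size - 1:
#     schedule.step(target=labels)   # labels enter only at the last stage
# else:
#     schedule.step(**example_inputs)

# Standalone check of loss_fn on dummy tensors:
logits = torch.randn(2, 8, 50257)
labels = torch.randint(0, 50257, (2, 8))
loss = loss_fn(logits, labels)
```
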