pytorch / PiPPy

Pipeline Parallelism for PyTorch

Code hangs permanently #1138

Open

Narasimha1997 commented 3 months ago

I was experimenting with loading the Qwen2 model with world size 2, with all workers placed entirely on CPU. The following is the code I was testing:

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import SplitPoint, pipeline, ScheduleGPipe

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

def load_model(rank: int):
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    # Example micro-batch used to trace the model into pipeline stages.
    mb_inputs = tokenizer(("How do you",),
                          return_tensors="pt", padding=True).to(torch.device("cpu"))

    # Split the model into two stages at the start of decoder layer 12.
    pipe = pipeline(model, mb_args=(mb_inputs["input_ids"],),
                    split_spec={"model.layers.12": SplitPoint.BEGINNING})

    # Build the stage owned by this rank, placed on CPU.
    stage = pipe.build_stage(rank, device=torch.device("cpu"))
    return stage

full_batch_prompts = (
    "How do you",
)

inputs = tokenizer(full_batch_prompts, return_tensors="pt",
                   padding=False).to(torch.device("cpu"))

# RANK is expected to be set by the launcher (e.g. torchrun).
rank = int(os.getenv("RANK"))

# gloo rendezvous also needs MASTER_ADDR and MASTER_PORT in the environment;
# `torchrun --nproc-per-node=2` sets these (and RANK) automatically.
torch.distributed.init_process_group(
    "gloo",
    rank=rank, world_size=2
)

stage = load_model(rank)

print('loaded model, now initiating pipeline')
schedule = ScheduleGPipe(stage, 1)  # GPipe schedule with a single micro-batch

if rank == 0:
    # Only the first stage is fed the real input ids.
    args = inputs["input_ids"]
    print(args.shape, args)
else:
    # Later stages receive activations from the previous stage instead.
    args = None

output = schedule.step(args)
print(f'{rank} - op - {output}')

if output is not None:
    # Greedy-decode the next token from the last stage's logits.
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))

The code hangs indefinitely. I do get the output printed from rank 0, but the processes never exit.
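
For comparison, the examples in the torch.distributed.pipelining docs only pass inputs to step() on the first rank, only read the return value on the last rank, and tear down the process group at the end. A minimal sketch of that pattern, reusing the stage, inputs, and tokenizer built above (untested against this particular hang):

import torch.distributed as dist

# Assumes `stage`, `inputs`, and `tokenizer` are built exactly as in the
# script above, and the script is launched with `torchrun --nproc-per-node=2`.
schedule = ScheduleGPipe(stage, 1)

if rank == 0:
    # First stage: feed the input ids; no output is returned here.
    schedule.step(inputs["input_ids"])
    output = None
else:
    # Last stage (rank 1 of 2): call step() with no args and collect the logits.
    output = schedule.step()

if output is not None:
    next_token = torch.argmax(output[0][:, -1, :], dim=-1)
    print(tokenizer.batch_decode(next_token))

# Tear down the process group so both ranks can exit cleanly.
dist.destroy_process_group()

Whether the hang comes from how step() is called on the non-first rank or from the missing destroy_process_group() teardown is not confirmed by this sketch.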