meta-llama / llama

Inference code for Llama models

AssertionError: model parallel group is not initialized #306

Open WencongY opened 1 year ago

WencongY commented 1 year ago

Hello team,

I'm trying to run example.py with the 7B model on a single GPU using this command:

torchrun --nproc_per_node 1 example.py --ckpt_dir ./llama_model/7B --tokenizer_path ./llama_model/tokenizer.model

but I get the following error:

Traceback (most recent call last):
  File "/gpfs/projects/user/LLM/llama/example.py", line 119, in <module>
    fire.Fire(main)
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/gpfs/projects/user/LLM/llama/example.py", line 78, in main
    generator = load(
  File "/gpfs/projects/user/LLM/llama/example.py", line 57, in load
    model = Transformer(model_args)
  File "/gpfs/projects/user/LLM/llama/llama/model.py", line 205, in __init__
    self.tok_embeddings = ParallelEmbedding(
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fairscale/nn/model_parallel/layers.py", line 186, in __init__
    world_size = get_model_parallel_world_size()
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fairscale/nn/model_parallel/initialize.py", line 152, in get_model_parallel_world_size
    return torch.distributed.get_world_size(group=get_model_parallel_group())
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/fairscale/nn/model_parallel/initialize.py", line 128, in get_model_parallel_group
    assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
AssertionError: model parallel group is not initialized
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 29705) of binary: /packages/miniconda/20190102/envs/user/bin/python
Traceback (most recent call last):
  File "/packages/miniconda/20190102/envs/user/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.0.0', 'console_scripts', 'torchrun')())
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/packages/miniconda/20190102/envs/user/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

Can you please advise how to handle this?

Thanks!

yhLeeee commented 1 year ago

Have you fixed this? Thx!

loretoparisi commented 4 weeks ago

Have you fixed this? Thx!

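This happens because the model-parallel group was never initialized: torch.distributed has to be set up, and fairscale's initialize_model_parallel called, before the model is built. At a minimum that requires: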

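import os
from typing import Tuple

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

def setup_model_parallel(rank, master_addr, master_port, world_size, backend='nccl') -> Tuple[int, int]:
    '''Initialize torch.distributed and the model-parallel group. This will not work with LightningModule.'''
    # Prefer values set by the launcher (e.g. torchrun); fall back to the arguments.
    rank = int(os.environ.get("RANK", rank))
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ.get("WORLD_SIZE", world_size))
    os.environ.setdefault("MASTER_ADDR", str(master_addr))
    os.environ.setdefault("MASTER_PORT", str(master_port))
    print("local_rank:", local_rank, "world_size:", world_size)

    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size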

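and, to launch it (with master_addr, master_port and world_size set for your environment):

import torch.multiprocessing as mp

def worker(rank, master_addr, master_port, world_size):
    # Build the model in the same process that initialized the group.
    setup_model_parallel(rank, master_addr, master_port, world_size)
    model = Llama3()  # placeholder for your model class

if __name__ == "__main__":
    # mp.spawn passes the process index (rank) as the first argument.
    mp.spawn(worker, args=(master_addr, master_port, world_size), nprocs=world_size)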
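
For the single-GPU case in the original question, it may help to resolve the launcher settings in one place before touching torch.distributed. Below is a minimal sketch, not part of the llama repo: DistEnv and resolve_dist_env are hypothetical names, and the fallback values (one process, 127.0.0.1, port 29500) are assumptions for a local run. It only reads the RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables that torchrun exports and checks them for consistency.

```python
import os
from typing import Mapping, NamedTuple


class DistEnv(NamedTuple):
    """Hypothetical container for the settings a process group needs."""
    rank: int
    local_rank: int
    world_size: int
    master_addr: str
    master_port: str


def resolve_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Read launcher-provided settings, defaulting to one local process."""
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))

    if world_size < 1:
        raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(f"RANK {rank} is outside [0, {world_size})")

    return DistEnv(
        rank=rank,
        local_rank=local_rank,
        world_size=world_size,
        # Assumed defaults for a single-machine run.
        master_addr=env.get("MASTER_ADDR", "127.0.0.1"),
        master_port=env.get("MASTER_PORT", "29500"),
    )


if __name__ == "__main__":
    print(resolve_dist_env())
```

The resolved rank and world_size could then be passed to torch.distributed.init_process_group, followed by initialize_model_parallel(world_size), before constructing the model.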