facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0
29.32k stars, 7.32k forks

refactor create local process group code for distributed training in "detectron2/utils/comm.py" #5202

[Open] zhuyuedlut opened this issue 5 months ago

zhuyuedlut commented 5 months ago

🚀 Feature

In "detectron2/utils/comm.py", there is a function named "create_local_process_group". The function is shown below:

@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    """
    Create a process group that contains ranks within the same machine.

    Detectron2's launch() in engine/launch.py will call this function. If you start
    workers without launch(), you'll have to also call this. Otherwise utilities
    like `get_local_rank()` will not work.

    This function contains a barrier. All processes must call it together.

    Args:
        num_workers_per_machine: the number of worker processes per machine. Typically
          the number of GPUs.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    assert get_world_size() % num_workers_per_machine == 0
    num_machines = get_world_size() // num_workers_per_machine
    machine_rank = get_rank() // num_workers_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            _LOCAL_PROCESS_GROUP = pg
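To make the rank arithmetic in the loop concrete, here is a minimal pure-Python sketch that reproduces which ranks go into each machine-local group and which group a given rank keeps. It does not call torch.distributed; the helper name local_group_layout is hypothetical, and it raises ValueError where the original uses an assert.

```python
from typing import List, Tuple


def local_group_layout(
    world_size: int, num_workers_per_machine: int, rank: int
) -> Tuple[List[List[int]], int]:
    """Hypothetical helper mirroring the arithmetic of create_local_process_group.

    Returns (all_groups, machine_rank): the rank list of every machine-local
    group, in the order the loop creates them, and the index of the group
    that `rank` would store in _LOCAL_PROCESS_GROUP.
    """
    if world_size % num_workers_per_machine != 0:
        raise ValueError("world_size must be a multiple of num_workers_per_machine")
    num_machines = world_size // num_workers_per_machine
    machine_rank = rank // num_workers_per_machine
    # One contiguous block of ranks per machine, created in machine order.
    all_groups = [
        list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
        for i in range(num_machines)
    ]
    return all_groups, machine_rank


groups, own = local_group_layout(world_size=8, num_workers_per_machine=4, rank=5)
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(own)     # 1
```

With 2 machines of 4 GPUs each, rank 5 belongs to machine 1, so it keeps the second group, [4, 5, 6, 7], while still taking part in creating [0, 1, 2, 3].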

As I understand it, the function calls dist.new_group for every machine but only stores the group for its own machine_rank in _LOCAL_PROCESS_GROUP. So it creates useless process groups for every i that is not equal to machine_rank.

Motivation & Examples

I think the function could be refactored as follows:

@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    assert get_world_size() % num_workers_per_machine == 0
    num_machines = get_world_size() // num_workers_per_machine
    machine_rank = get_rank() // num_workers_per_machine
    for i in range(num_machines):
        if i == machine_rank:
            ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
            pg = dist.new_group(ranks_on_i)
            _LOCAL_PROCESS_GROUP = pg
ppwwyyxx commented 4 months ago

Your suggestion doesn't follow the requirement of PyTorch's dist.new_group.
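For context, the PyTorch documentation for torch.distributed.new_group states that all processes in the main group must enter the function, even if they are not going to be members of the new group, and that groups should be created in the same order in all processes. The sketch below is a pure-Python model of that contract, not real torch.distributed code: it records the rank lists each rank would pass to new_group under the original loop and under the proposed refactor, then checks whether every rank issues identical calls in identical order. All function names here are hypothetical.

```python
from typing import Callable, List

CallSequence = List[List[int]]


def original_calls(rank: int, world_size: int, n: int) -> CallSequence:
    """Rank lists passed to new_group() by the original loop.

    `rank` is unused: every rank runs the same loop and creates every group.
    """
    num_machines = world_size // n
    return [list(range(i * n, (i + 1) * n)) for i in range(num_machines)]


def proposed_calls(rank: int, world_size: int, n: int) -> CallSequence:
    """Rank lists passed to new_group() by the proposed refactor (own machine only)."""
    machine_rank = rank // n
    return [list(range(machine_rank * n, (machine_rank + 1) * n))]


def satisfies_new_group_contract(
    strategy: Callable[[int, int, int], CallSequence], world_size: int, n: int
) -> bool:
    """Model of the contract: all ranks make the same new_group() calls, in order."""
    sequences = [strategy(r, world_size, n) for r in range(world_size)]
    return all(seq == sequences[0] for seq in sequences)


print(satisfies_new_group_contract(original_calls, 8, 4))  # True
print(satisfies_new_group_contract(proposed_calls, 8, 4))  # False
```

With 2 machines, rank 0 would call new_group([0, 1, 2, 3]) while rank 4 would call new_group([4, 5, 6, 7]), so the ranks disagree on what is being created. On a single machine (world_size equal to num_workers_per_machine) both versions happen to coincide, which is why the problem only shows up in multi-machine jobs. The groups a rank never uses are the price of keeping every process in lockstep.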