Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Returning num_replicas=world_size when using distributed sampler in ddp #19961

Open arjunagarwal899 opened 5 months ago

arjunagarwal899 commented 5 months ago

Bug description

The default LightningEnvironment assumes that every node in a multi-node environment has an equal number of GPUs, i.e. each node computes the world size as the number of nodes multiplied by the number of (active) devices on that node.

However, implementing one's own environment can bypass this limitation (example attached below). While the processes get registered successfully, the num_replicas attribute that is passed to the DistributedSampler class is still computed independently of the environment, which leads to an error because some global ranks fall outside the world size.

Fix: Use num_replicas=self.world_size() instead of estimating the world size again.
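For context, a minimal standalone sketch of the failure mode (illustrative dataset and numbers, not taken from the actual run): torch.utils.data.distributed.DistributedSampler validates rank against num_replicas at construction time, which is exactly where the error below surfaces.

```
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))  # illustrative dataset

# num_replicas derived as num_nodes * num_processes = 4, but a heterogeneous
# cluster can assign global ranks up to 5, so rank 4 is rejected:
DistributedSampler(dataset, num_replicas=4, rank=4)
# ValueError: Invalid rank 4, rank should be in the interval [0, 3]
```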

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Config and custom Environment:

import os
import socket

import lightning as L
from lightning.pytorch.plugins.environments import LightningEnvironment

# Code that sets some config parameters as follows
config.nodes = [
    ("NODE_NAME", ["NODE_IP_ADDRESS", "NUM_GPUS: int"]),  # Index 0 is master node
    ...
]
config.num_nodes = len(config.nodes)

# Set some global variables
MASTER_PORT = 10051 # Set port here
MASTER_ADDR = config.nodes[0][1][0] # Set address here
WORLD_SIZE = sum([node_info[1] for _, node_info in config.nodes]) # Set world size here

# Set config devices, NODE_RANK, and global rank starting point
NODE_RANK = ...
GLOBAL_RANK_OFFSET = 0
for i, (node, node_info) in enumerate(config.nodes):
    if node == socket.gethostname():
        config.devices = node_info[1]
        NODE_RANK = i
        break
    GLOBAL_RANK_OFFSET += node_info[1]

# Set environment variables
os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
os.environ["NODE_RANK"] = str(NODE_RANK)

class MyClusterEnvironment(LightningEnvironment):
    def set_world_size(self, size: int):
        # Here, size = num_nodes * len(devices), which does not work for heterogeneous clusters
        self._world_size = WORLD_SIZE

    def set_global_rank(self, rank: int):
        # Here, global_rank = node_rank * len(devices) + local_rank, which does not work for heterogeneous clusters
        global_rank = GLOBAL_RANK_OFFSET + self.local_rank()
        self._global_rank = global_rank

config.cluster_environment = MyClusterEnvironment()
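As a quick local sanity check (hypothetical, not part of the reproduction itself), the overridden setters can be exercised directly; they ignore the values Lightning would pass in and use the precomputed globals instead. This assumes the world_size()/global_rank()/local_rank() accessors of the ClusterEnvironment interface.

```
env = MyClusterEnvironment()
env.set_world_size(config.num_nodes * config.devices)  # argument is ignored by the override
env.set_global_rank(0)                                  # argument is ignored by the override

assert env.world_size() == WORLD_SIZE
assert env.global_rank() == GLOBAL_RANK_OFFSET + env.local_rank()
```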

Trainer:

trainer = L.Trainer(
    num_nodes=config.num_nodes,
    devices=config.devices,
    plugins=[config.cluster_environment],
    ... # Other arguments
)

Run on:

Error messages and logs

Error on node 1:

╭────────────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────────────╮
│ /home/users/arjun.agarwal/projects/mock_training/distributed.py:99 in <module>                                                 │
│                                                                                                                                │
│    96 │   │   plugins=[config.cluster_environment],                                                                            │
│    97 │   )                                                                                                                    │
│    98 │                                                                                                                        │
│ ❱  99 │   trainer.fit(model, dm)                                                                                               │
│   100                                                                                                                          │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544 in fit               │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43 in                       │
│ _call_and_handle_interrupt                                                                                                     │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py:1 │
│ 05 in launch                                                                                                                   │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580 in _fit_impl         │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987 in _run              │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1031 in _run_stage       │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1060 in                  │
│ _run_sanity_check                                                                                                              │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py:182 in _decorator        │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:110 in run         │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py:180 in setup_data  │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:501 in │
│ _process_dataloader                                                                                                            │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:206 in │
│ _prepare_dataloader                                                                                                            │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:217 in │
│ _resolve_sampler                                                                                                               │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:258 in │
│ _get_distributed_sampler                                                                                                       │
│                                                                                                                                │
│ /home/users/arjun.agarwal/miniconda3/lib/python3.9/site-packages/torch/utils/data/distributed.py:74 in __init__                │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Invalid rank 4, rank should be in the interval [0, 3]

num_replicas gets set to 4 here because num_nodes=2 and num_processes=2. However, the world size is 6 as defined in the environment.
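To make the mismatch concrete, here is the arithmetic for a hypothetical two-node layout consistent with the reported numbers (4 GPUs on the main node, 2 on the second; not necessarily the actual cluster):

```
nodes = {"node0": 4, "node1": 2}

world_size = sum(nodes.values())          # 6, what the custom environment reports
num_nodes = len(nodes)                    # 2
num_processes_on_node1 = nodes["node1"]   # 2, the `devices` passed to the Trainer on node1

# What ddp.py currently hands to DistributedSampler on node1:
num_replicas = num_nodes * num_processes_on_node1   # 2 * 2 = 4

# Global ranks on node1 start after node0's four processes:
global_ranks_on_node1 = [4, 5]            # both >= num_replicas -> "Invalid rank"
```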

Environment

Current environment

```
* Lightning:
    - efficientnet-pytorch: 0.7.1
    - lightning: 2.2.5
    - lightning-cloud: 0.5.61
    - lightning-utilities: 0.10.0
    - pytorch-lightning: 2.1.2
    - pytorchvideo: 0.1.5
    - torch: 2.2.2
    - torchaudio: 2.2.2
    - torchmetrics: 1.2.1
    - torchsummary: 1.5.1
    - torchvision: 0.17.2
```

More info

The issue can be fixed by replacing ddp.py:L137

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

with


    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return {"num_replicas": self.world_size, "rank": self.global_rank}

cc @borda @awaelchli @justusschock

arjunagarwal899 commented 4 months ago

@lantiga @Borda @tchaton @awaelchli @justusschock Any feedback on this?

awaelchli commented 4 months ago

Thanks for the interest in this, @arjunagarwal899. It's not a bug: Lightning deliberately assumes this setting in several places, and the docs and examples never show any other setting, so it's definitely not expected to work. The change you propose in the PR is unfortunately not enough (it may be enough in your case); more places would need updates, along with thorough testing and extended tests (feel free to give it a try). We won't have the bandwidth to help with this in the near future, but if the community is willing to put in the effort, we'd of course be happy to review PRs for it!

Other issues: #19898, #14078

shaibagon commented 3 months ago

@awaelchli @arjunagarwal899 I am facing a similar issue. I do not have the capacity to write this PR, but I may be able to help test it.