Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Does `DDPStrategy` support XLA? #19766

Closed laserkelvin closed 2 months ago

laserkelvin commented 5 months ago

Bug description

When configuring a `DDPStrategy` with multiple devices that do not use the `torch.cuda` API, the following exception is raised:

  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream

The `_setup_model` method of `DDPStrategy` triggers this exception because `torch.cuda.stream` is hardcoded whenever `device_ids` are passed. I've reproduced the snippet below, but here is a permalink.

    @override
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

A potential solution could be to check the target device, or simply to check `torch.cuda.is_available()` in the condition. Removing the `torch.cuda.Stream()` call and just using `nullcontext()` otherwise works perfectly fine; a sketch follows below.
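For illustration only, here is a minimal sketch of that guard (the helper name `_ddp_setup_context` is hypothetical and not part of Lightning; it only gates the CUDA stream on `torch.cuda.is_available()`):

from contextlib import nullcontext

import torch


def _ddp_setup_context(device_ids):
    """Return the context to wrap DDP construction in.

    Only enter a CUDA side stream when device ids are given AND CUDA is
    actually usable; on non-CUDA accelerators (e.g. XPU) fall back to a
    no-op context so torch.cuda.Stream() is never instantiated.
    """
    if device_ids is not None and torch.cuda.is_available():
        return torch.cuda.stream(torch.cuda.Stream())
    return nullcontext()

`_setup_model` could then use `with _ddp_setup_context(device_ids):` around the `DistributedDataParallel(...)` call instead of building the context inline.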

The reproduction provided below relies on an XPUAccelerator registered here, but I would assume the same issue is triggered for other non-CUDA accelerators as well.

What version are you seeing the problem on?

v2.1, v2.2

How to reproduce the bug

import pytorch_lightning as pl

# num_devices, num_nodes, task (a LightningModule) and dm (a LightningDataModule)
# are defined elsewhere in the training script.
env = pl.plugins.environments.SLURMEnvironment()
ddp = pl.strategies.DDPStrategy(
    accelerator="xpu",
    cluster_environment=env,
    process_group_backend="ccl",
    find_unused_parameters=True,
)

trainer = pl.Trainer(
    strategy=ddp, devices=num_devices, fast_dev_run=100, num_nodes=num_nodes
)
trainer.fit(task, datamodule=dm)

Error messages and logs

  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.2.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: N/A
#- GPU models and configuration: Intel 1550 Data Center GPU Max
#- How you installed Lightning (`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud): Managed Slurm cluster
```

More info

No response

cc @justusschock @awaelchli

awaelchli commented 3 months ago

Hi @laserkelvin

The DDPStrategy does not support XLA, nor does the DDP implementation in PyTorch. For distributed training with XLA, please use

Trainer(accelerator="tpu", devices=8)

Docs: https://lightning.ai/docs/pytorch/stable/accelerators/tpu.html
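For completeness, a minimal sketch of such a run using Lightning's demo classes (it assumes a TPU host with torch_xla installed; `BoringModel` and `RandomDataset` are just stand-ins from the bug-report template):

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset

# Placeholder model and data, only to show the Trainer wiring.
model = BoringModel()
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

# With accelerator="tpu", the XLA-based strategy is selected automatically;
# no DDPStrategy is involved.
trainer = pl.Trainer(accelerator="tpu", devices=8)
trainer.fit(model, train_dataloaders=train_data)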

We won't be able to support XLA+DDP as requested.