Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support user-defined parallelization in the LightningModule #11922

Closed: ananthsub closed this 4 months ago

ananthsub commented 2 years ago

🚀 Feature

Support a manual parallelization option

Motivation

Now that the Strategy refactor is complete, a step change in research flexibility is possible. Users no longer have to override two different classes (TrainingTypePlugin & Accelerator) to implement custom parallelism handling, which widens the set of use cases Lightning can support as a training-loop framework.

There are users who have highly customized parallelization requirements.

For instance:

Rather than requiring each of these users to learn the Strategy codebase in depth before they can customize this, I propose a "manual" parallel strategy that delegates this logic back to the LightningModule.

This way, all of the modeling logic sits in one place, which makes it easier for researchers to get started without learning another abstraction. If these techniques prove more general, they can be abstracted to fit the Strategy interface, making them shareable across projects.

In this setting, the user assumes responsibility for the following:

The Trainer/Strategy will still handle:

This is intended for power users who know exactly what they're doing. The term "manual parallel" follows the precedent of manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization

This is also the motivation for the PRs removing dependencies on LightningModule.device within the Trainer:

LightningModule.device is not properly defined for use cases where the LightningModule's parameters sit on multiple devices. This proposal aims to remove the requirement for users of these LightningModules to call LightningModule.to(...) before executing a Trainer function.
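To make the problem concrete, here is a minimal sketch in plain PyTorch (no Lightning involved) of a module whose parameters span two devices. The "meta" device stands in for a second GPU so the example runs on a CPU-only machine; `TwoDeviceModel` and `parameter_devices` are hypothetical names used only for illustration.

```python
import torch
from torch import nn


class TwoDeviceModel(nn.Module):
    """Toy model whose two layers deliberately live on different devices."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 4, device="cpu")
        # "meta" stands in for a second accelerator: shapes only, no storage.
        self.decoder = nn.Linear(4, 4, device="meta")


def parameter_devices(module: nn.Module) -> set:
    """Collect every device that holds at least one parameter of `module`."""
    return {p.device for p in module.parameters()}


model = TwoDeviceModel()
# A single `device` attribute cannot describe this set of devices.
print(sorted(str(d) for d in parameter_devices(model)))  # ['cpu', 'meta']
```

For such a module there is no single correct value to report as its device, so any Trainer logic that reads one is making an assumption the user may not satisfy.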

Pitch

Define a new strategy class like this:

# import paths assume the PL 1.6-era package layout
import os
from typing import Any, Dict, List, Optional, TypeVar, Union

import torch
from torch.distributed import ReduceOp

import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies import ParallelStrategy
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.types import _PATH

TBroadcast = TypeVar("TBroadcast")


class ManualParallelStrategy(ParallelStrategy):
    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    def setup_environment(self) -> None:
        # start the other scripts; _call_children_scripts is assumed to be borrowed
        # from DDPStrategy, since ParallelStrategy does not provide it
        if not self.cluster_environment.creates_processes_externally:
            self._call_children_scripts()

        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        # initialize process group if not already available
        # (the user or an external launcher may have created it already)
        if not torch.distributed.is_initialized():
            os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
            os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
            backend = "nccl" if self.root_device.type == "cuda" else "gloo"
            torch.distributed.init_process_group(backend, rank=self.global_rank, world_size=self.world_size)

    @property
    def root_device(self) -> torch.device:
        """The device where data is loaded to."""
        return self.parallel_devices[self.local_rank]

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """
        # By default, enable saving on all ranks for distributed checkpointing
        self.checkpoint_io.save_checkpoint(checkpoint, filepath)

    def barrier(self, *args, **kwargs) -> None:
        if not torch.distributed.is_initialized():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            # pin the barrier to this rank's GPU instead of letting NCCL guess
            torch.distributed.barrier(device_ids=[self.root_device.index])
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        obj = [obj] if self.global_rank == src else [None]
        torch.distributed.broadcast_object_list(obj, src, group=torch.distributed.group.WORLD)
        return obj[0]

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.
        Return:
            reduced value, except when the input was not a tensor, in which case the output is unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """
        self.precision_plugin.teardown()
        self.cluster_environment.teardown()
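The broadcast and reduce methods above wrap torch.distributed collectives. Here is a minimal, Lightning-free sketch of what they compute, run as a single gloo process with a file-based rendezvous so it needs no GPU and no free TCP port (both are choices made only for the demo). It implements the "mean" reduction as an all-reduce SUM divided by the world size, which is how we assume Lightning's sync_ddp_if_available treats "mean"; `broadcast_obj`, `mean_reduce` and `run_demo` are hypothetical names.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def broadcast_obj(obj, src: int = 0):
    """Send a picklable object from rank `src` to every rank."""
    payload = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(payload, src=src)  # default (WORLD) group
    return payload[0]


def mean_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce with SUM, then divide by the world size to get the mean."""
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced / dist.get_world_size()


def run_demo():
    # File-based rendezvous; world size 1 so the sketch runs in one process.
    store_path = os.path.join(tempfile.mkdtemp(), "pg_store")
    dist.init_process_group("gloo", init_method=f"file://{store_path}", rank=0, world_size=1)
    try:
        config = broadcast_obj({"lr": 0.1})
        loss = mean_reduce(torch.tensor([2.0, 4.0]))
    finally:
        dist.destroy_process_group()
    return config, loss


if __name__ == "__main__":
    print(run_demo())
```

With more than one process, every rank would receive rank 0's config and the element-wise mean of all ranks' tensors.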

Example of a LightningModule that is inherently distributed-aware:


import os

import torch
import torch.distributed as dist
from pytorch_lightning import LightningModule, Trainer


class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        rank = int(os.environ["LOCAL_RANK"])  # set by an external launcher such as torchrun
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            backend = "nccl"
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")
            backend = "gloo"

        if not dist.is_initialized():
            dist.init_process_group(backend=backend)
        # MyHugeModel, shard_huge_model and MyOptimizer are user-provided placeholders
        model = MyHugeModel(...)  # might require process group to already be available
        self.model = shard_huge_model(model, device)  # might require process group to already be available
        self.optimizer = MyOptimizer(self.model.parameters())  # model is already on the correct device, so this is safe to initialize now

    # training_step etc. omitted for brevity

    def configure_optimizers(self):
        return self.optimizer

# strategy="manual" is the name proposed in this issue; it does not exist yet
trainer = Trainer(strategy="manual", accelerator="gpu", devices=8)
lit_model = MyLightningModule()
trainer.fit(lit_model)

Alternatives

Additional context

Idea for manual parallelization was also raised here: https://github.com/PyTorchLightning/pytorch-lightning/issues/8722#issuecomment-922699686



cc @borda @awaelchli @rohitgr7 @akihironitta

carmocca commented 2 years ago

Should the model be wrapped in setup instead? That would avoid the following, right?

        rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            backend = "nccl"
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")
            backend = "gloo"

        if not torch.distributed.is_initialized():
            dist.init_process_group(backend=backend)

The code above (wrapping in __init__) won't work with DDP spawn. It would also be cleaner to let Lightning create the process group etc., so users only need to wrap the model and create the optimizers.

Should the creation of the process group be completely customizable for support with strategies like DeepSpeed or Bagua?

ananthsub commented 2 years ago

Initializing the model in setup has a few downsides:

  1. The model initialization & training logic are coupled together. Ideally we would have the model initialization external to the lightning module. This way, we use the lightning module as a system, as recommended by the docs. Otherwise, the lightning module needs to know how to initialize & shard the provided models. Different models may have different APIs/behaviors, all of which end up inside of setup (e.g. the MyHugeModel and shard_huge_model methods above).
  2. Initializing the model inside of setup also complicates loading checkpoints through load_from_checkpoint. This is similar to the problem FSDP faces when configure_sharded_model is used to do the sharding.

The code above (wrapping in __init__) won't work in DDP spawn.

The LightningModule code is determined by the user, so they would have to decide whether their code needs to work with DDP spawn. I think it's going to be hard to support custom parallelization, checkpoint loading, and spawning all at once.

it would also be cleaner to let Lightning create the process group etc so the users just need to wrap and create the optimizers. Should the creation of the process group be completely customizable for support with strategies like DeepSpeed or Bagua?

Users can already initialize the process group themselves if they create the processes externally. The only cases where Lightning has to create the process group are spawn and subprocess script launch. A lighter form of customization is being worked on in https://github.com/PyTorchLightning/pytorch-lightning/pull/11745 .
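For the externally launched case, a minimal sketch of what initializing the process group yourself can look like, assuming the processes were started by torchrun (which exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT for every process). The helper names are hypothetical, and `emulate_single_process_launch` only fakes a one-process launch so the sketch can run standalone.

```python
import os
import socket

import torch
import torch.distributed as dist


def launched_externally() -> bool:
    """Heuristic: launchers like torchrun export these variables for every process."""
    return all(key in os.environ for key in ("RANK", "WORLD_SIZE", "LOCAL_RANK"))


def maybe_init_process_group() -> bool:
    """Create the default process group unless one exists; return True if created here."""
    if dist.is_initialized():
        return False
    if not launched_externally():
        raise RuntimeError("RANK/WORLD_SIZE/LOCAL_RANK not set; was this started by a launcher?")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE.
    dist.init_process_group(backend=backend, init_method="env://")
    return True


def emulate_single_process_launch() -> None:
    """Fake the environment of a one-process torchrun launch (local testing only)."""
    with socket.socket() as sock:
        sock.bind(("127.0.0.1", 0))  # ask the OS for a free port
        port = sock.getsockname()[1]
    os.environ.update(
        RANK="0", WORLD_SIZE="1", LOCAL_RANK="0", MASTER_ADDR="127.0.0.1", MASTER_PORT=str(port)
    )


if __name__ == "__main__":
    emulate_single_process_launch()
    print(maybe_init_process_group())  # True: this call created the group
    print(maybe_init_process_group())  # False: already initialized
    dist.destroy_process_group()
```

Because the function is a no-op when a group already exists, it can be called both from user code and from a strategy without creating the group twice.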

Note: I don't want to make this an issue about whether to support spawning. I only want to point out that relying on the Lightning Trainer for process creation imposes restrictions on how users author their training programs. For the use cases I've seen, especially ones that would benefit from this strategy, we have been using torchx to great effect.

carmocca commented 2 years ago

Ideally we would have the model initialization external to the lightning module. This way, we use the lightning module as a system, as recommended by the docs.

LightningModule.setup could call an nn.Module.setup method defined by the user to avoid this.
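A minimal sketch of that suggestion, kept framework-free so it runs without Lightning: `System` stands in for the LightningModule, and `ShardableModel.setup` is a user-defined hook (a plain nn.Module has no setup method). All class names are hypothetical, and the "sharding" is a placeholder that simply builds the layers on a device.

```python
import torch
from torch import nn


class ShardableModel(nn.Module):
    """User model that defers heavy construction to a user-defined setup() hook."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.hidden = hidden
        self.layers = None  # built lazily in setup()

    def setup(self, device: torch.device) -> None:
        if self.layers is not None:  # idempotent: may be called once per stage
            return
        # Placeholder for real sharding logic that may need a process group.
        self.layers = nn.Sequential(nn.Linear(self.hidden, self.hidden), nn.ReLU()).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class System:
    """Stand-in for a LightningModule: owns training logic, not model construction."""

    def __init__(self, model: ShardableModel) -> None:
        self.model = model  # constructed outside the system
        self.optimizer = None

    def setup(self, stage: str) -> None:
        # Forward to the user's hook, then build the optimizer once parameters exist.
        self.model.setup(torch.device("cpu"))
        if self.optimizer is None:
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)


system = System(ShardableModel(hidden=8))
system.setup("fit")
print(system.model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

The idempotency check matters because a LightningModule's setup hook runs for each stage (for example fit and then test), and rebuilding the model there would discard trained weights.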

also runs into complications with loading checkpoints

This is a fair point. However, the docs for this feature should at least show both options and explain their differences.