Closed ananthsub closed 4 months ago
Should the model be wrapped in `setup` instead? It would avoid the following, right?
```python
import os

import torch
import torch.distributed as dist

rank = int(os.environ["LOCAL_RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"

if not dist.is_initialized():
    dist.init_process_group(backend=backend)
```
The code above (wrapping in `__init__`) won't work in DDP spawn. It would also be cleaner to let Lightning create the process group etc., so users only need to wrap the model and create the optimizers.

Should the creation of the process group be completely customizable, to support strategies like DeepSpeed or Bagua?
Initializing the model in `setup` has a few downsides:

- The model definition gets split across `__init__` and `setup` (e.g. the `MyHugeModel` and `shard_huge_model` methods above).
- It complicates `load_from_checkpoint`. This is a similar problem faced with FSDP when using `configure_sharded_model` to do the sharding.

> The code above (wrapping in `__init__`) won't work in DDP spawn.
The LightningModule code is written by the user, so the user would have to decide whether their code needs to work with DDP spawn or not. I think it's going to be hard to support custom parallelization, checkpoint loading, and spawning all simultaneously.
> It would also be cleaner to let Lightning create the process group etc., so users only need to wrap the model and create the optimizers. Should the creation of the process group be completely customizable, to support strategies like DeepSpeed or Bagua?
Users can already initialize the process group themselves if they create the processes externally. The only instances where Lightning has to create the process group are spawn and subprocess script launch. A lighter form of customization is being worked on in https://github.com/PyTorchLightning/pytorch-lightning/pull/11745 .
Note: I don't want to make this an issue about supporting spawning vs. not. This is only to state that relying on the Lightning Trainer to do the process creation imposes restrictions on how users author their training programs. In the use cases I've seen, especially ones that would benefit from this strategy, we have been using torchx to great effect.
Ideally we would have the model initialization external to the LightningModule. This way, we use the LightningModule as a system, as recommended by the docs.
`LightningModule.setup` could call an `nn.Module.setup` defined by the user to avoid this.

> It also runs into complications with loading checkpoints.

This is a fair point. However, the docs for this feature should at least show both options and mention their differences.
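The delegation suggested above can be sketched in plain Python (torch-free stand-ins; the hook name `setup` and the recursive traversal are assumptions, not an existing Lightning API):

```python
class Submodule:
    """Stand-in for an nn.Module whose author defines a setup hook."""

    def __init__(self, children=()):
        self.children = list(children)
        self.was_set_up = False

    def setup(self):
        # In a real model, this is where the user would shard or wrap
        # this submodule's parameters.
        self.was_set_up = True


def lightning_setup(root):
    """What LightningModule.setup could do: recurse into every child
    and call its user-defined setup hook."""
    for child in root.children:
        lightning_setup(child)
    root.setup()


model = Submodule(children=[Submodule(), Submodule(children=[Submodule()])])
lightning_setup(model)
```

The point of the pattern is that the user keeps full control over what each submodule's `setup` does, while Lightning only guarantees when it is called.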
## 🚀 Feature

Support a manual parallelization option.

### Motivation
Now that the Strategy refactor is complete, users no longer have to override two different classes (TrainingTypePlugin & Accelerator) to implement custom parallelism handling. This unlocks a step change in research flexibility and widens the set of use cases Lightning can support as a training-loop framework.
There are users who have highly customized parallelization requirements. For instance:

- Model parallelism variants: Users of plain PyTorch can partition parameters across devices. However, Lightning so far has not allowed this: distributed training forces some sort of data-parallel variant. Lightning natively doesn't support open-ended model parallelism because the `nn.Module`s inside the LightningModule are a black box to the Trainer (this is necessary for the Trainer's generality). Example use case: recommendation models often have large embedding tables that cannot fit on a single device. The sharding of these tables is highly customized, and there is no single-device alternative for training these models, so the modeling logic is written in a way that assumes distributed training.
- Combining module wrappers: Some users may want to wrap parts of their model in DDP and partition other parts with techniques like FSDP.
- Variable batch sizes that require normalizing the reduced gradient by the number of samples across all batches used in that step, rather than the DDP default of normalizing by world size.
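To make the last point concrete, here is a toy, dependency-free sketch (no actual DDP involved; the function names are illustrative) contrasting world-size normalization with normalization by the total sample count:

```python
def ddp_mean(per_rank_grads):
    """DDP default: average the per-rank gradients by world size."""
    return sum(per_rank_grads) / len(per_rank_grads)


def sample_weighted_mean(per_rank_grads, per_rank_counts):
    """Normalize the reduced gradient by the total number of samples."""
    total = sum(per_rank_counts)
    return sum(g * n for g, n in zip(per_rank_grads, per_rank_counts)) / total


# Two ranks with uneven batches: rank 0 averaged its gradient over 3
# samples, rank 1 over just 1. The world-size average (~0.8) over-weights
# the small batch; the sample-weighted average (~0.7) is the true mean
# gradient over all 4 samples.
grads, counts = [0.6, 1.0], [3, 1]
print(ddp_mean(grads), sample_weighted_mean(grads, counts))
```

Implementing the weighted variant in practice requires hooking gradient reduction yourself, which is exactly the kind of customization a manual strategy would permit.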
Rather than require each of these users to learn all about the Strategy codebase to be able to customize this, I propose a "manual" parallel strategy which delegates this logic back to the LightningModule.
This way, all of the modeling logic sits in one place. This is easier for researchers to get started without needing to learn another abstraction. If these techniques pan out to be more general, they can be abstracted out to fit into the Strategy interface, which makes them shareable across projects.
In this setting, the user assumes responsibility for the following:
The Trainer/Strategy will still handle:
This is intended for power users who know exactly what they're doing. The terminology manual parallel follows the precedent of manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization
This is also the motivation for the PRs removing dependencies on `LightningModule.device` within the Trainer: `LightningModule.device` is not properly defined for use cases where the LightningModule's parameters sit on multiple devices. This proposal aims to remove the requirement for users of these LightningModules to call `LightningModule.to(...)` before executing a Trainer function.

### Pitch
Define a new strategy class like this:
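The original code block did not survive extraction. As a placeholder, here is a hedged, dependency-free sketch of the shape such a class could take; the name `ManualParallelStrategy` and every hook below are assumptions for illustration, not the API from the original issue:

```python
class ManualParallelStrategy:
    """Sketch of a strategy that delegates parallelization to the module.

    Unlike DDPStrategy, it performs no wrapping, no device placement,
    and no gradient synchronization of its own.
    """

    def __init__(self, process_group_backend=None):
        # Lightning would still create the process group for the spawn and
        # subprocess launchers; externally launched jobs can init it first.
        self.process_group_backend = process_group_backend
        self.model = None

    def connect(self, model):
        # The model is assumed to arrive already sharded/wrapped by the user.
        self.model = model

    def setup(self):
        # No DistributedDataParallel(...) here: the user's module is used
        # as-is, so arbitrary wrapper combinations remain possible.
        return self.model

    def model_to_device(self):
        # Intentionally a no-op: parameters may sit on several devices, so
        # there is no single root device to move the module to.
        pass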
Example of a LightningModule which is inherently distributed aware
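The module example was also lost in extraction. As a stand-in, the sketch below shows the rank-aware arithmetic such a module would run in its `__init__` to own only its slice of a large embedding table (contiguous row sharding; the helper name is made up, and a real module would obtain `world_size`/`rank` from `torch.distributed`):

```python
def shard_bounds(num_rows, world_size, rank):
    """Half-open row range [start, end) of the table owned by `rank`.

    Contiguous sharding: the first `num_rows % world_size` ranks take one
    extra row, so every row is owned by exactly one rank.
    """
    base, extra = divmod(num_rows, world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


# A distributed-aware module would allocate only its shard, e.g.
#   start, end = shard_bounds(10, dist.get_world_size(), dist.get_rank())
#   self.table = nn.Embedding(end - start, dim)
print(shard_bounds(10, 4, 1))  # (3, 6): rank 1 of 4 owns rows 3..5
```

Because the constructor already depends on the rank, there is no meaningful single-device instantiation of such a module, which is exactly why it needs a manual strategy rather than a data-parallel one.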
### Alternatives

### Additional context
Idea for manual parallelization was also raised here: https://github.com/PyTorchLightning/pytorch-lightning/issues/8722#issuecomment-922699686
cc @borda @awaelchli @rohitgr7 @akihironitta