Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Move strategy-specific dataloader logic to the strategies #11756

Open · ananthsub opened 2 years ago

ananthsub commented 2 years ago

Proposed refactor

Motivation

Strategies today influence dataloading, especially in distributed training, so it makes sense for the strategy to handle this logic directly.

This would reduce and simplify interactions elsewhere in the Trainer, in particular the chain of strategy state -> trainer properties -> data connector logic.

It would also remove the hacky patching that the IPU strategy does right now 😧 https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/strategies/ipu.py#L125-L130

Pitch

The Strategy interface already offers a process_dataloader method: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/strategies/strategy.py#L369-L375
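
For reference, the base hook is currently just a pass-through (paraphrased from the linked lines, not verbatim):

```python
from torch.utils.data import DataLoader


class Strategy:
    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # The default implementation returns the dataloader unchanged; a strategy
        # can override this to wrap or replace it (as the IPU strategy does).
        return dataloader
```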

However, there's a ton of strategy-specific logic written in the trainer's data connector:

- warnings specific to DDP spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/trainer/connectors/data_connector.py#L278-L312
- a generic multiprocessing warning: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/trainer/connectors/data_connector.py#L314-L322
- the is_distributed flag, which is primarily used here within the trainer: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/trainer/connectors/data_connector.py#L324-L330
- distributed_sampler_kwargs, which is strategy-specific: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/trainer/trainer.py#L2126-L2129
- an IPU-specific check: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/trainer/connectors/data_connector.py#L364

In my opinion, we could simplify this by moving the relevant logic into strategy.process_dataloader instead. Common logic can still be abstracted out into utility functions shared across the different strategy classes.
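
As a rough sketch of that direction for a DDP-style strategy (replace_sampler, world_size, global_rank and _update_dataloader below are illustrative stand-ins, not existing Lightning APIs):

```python
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler


def _update_dataloader(dataloader: DataLoader, sampler) -> DataLoader:
    # Stand-in for a shared utility that rebuilds the DataLoader with a new
    # sampler while preserving its other arguments (batch size, workers, ...).
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
        drop_last=dataloader.drop_last,
    )


class DDPStrategySketch:
    world_size: int = 1
    global_rank: int = 0
    replace_sampler: bool = True

    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # Skip loaders the user has already configured for distributed use,
        # or when sampler replacement has been turned off.
        if not self.replace_sampler or isinstance(dataloader.sampler, DistributedSampler):
            return dataloader
        sampler = DistributedSampler(
            dataloader.dataset,
            num_replicas=self.world_size,
            rank=self.global_rank,
            shuffle=isinstance(dataloader.sampler, RandomSampler),
        )
        return _update_dataloader(dataloader, sampler)
```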

We could either extend the single generic method with an argument indicating which stage the dataloader is for:

```python
def process_dataloader(self, dataloader, use: str) -> Union[DataLoader, Iterable]:
    ...
```

or offer multiple APIs that map to the DataLoader hooks: https://github.com/PyTorchLightning/pytorch-lightning/blob/cc43d07db1ab77385feff04c01f040c5cad805a9/pytorch_lightning/core/hooks.py#L406

```python
def process_train_dataloader(self, dataloader): ...
def process_val_dataloader(self, dataloader): ...
def process_test_dataloader(self, dataloader): ...
def process_predict_dataloader(self, dataloader): ...
```
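
One way the two shapes could coexist (a sketch, not something the issue settles on): let the stage-specific hooks delegate to the generic method by default, so a strategy only overrides whichever level of granularity it cares about.

```python
from typing import Iterable, Union

from torch.utils.data import DataLoader


class Strategy:
    # Single generic hook: `use` tells the strategy which stage the loader feeds.
    def process_dataloader(self, dataloader: DataLoader, use: str) -> Union[DataLoader, Iterable]:
        return dataloader

    # Stage-specific hooks fall back to the generic one.
    def process_train_dataloader(self, dataloader: DataLoader) -> Union[DataLoader, Iterable]:
        return self.process_dataloader(dataloader, use="train")

    def process_val_dataloader(self, dataloader: DataLoader) -> Union[DataLoader, Iterable]:
        return self.process_dataloader(dataloader, use="val")

    def process_test_dataloader(self, dataloader: DataLoader) -> Union[DataLoader, Iterable]:
        return self.process_dataloader(dataloader, use="test")

    def process_predict_dataloader(self, dataloader: DataLoader) -> Union[DataLoader, Iterable]:
        return self.process_dataloader(dataloader, use="predict")
```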

Over time, the Trainer flag replace_sampler_ddp would also make much more sense as an argument on the specific distributed strategy constructors than on the Trainer.
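
As a rough illustration of that direction (the constructor argument name below is hypothetical, not an agreed API):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Today: a global Trainer flag, even though only distributed strategies use it.
trainer = Trainer(strategy="ddp", replace_sampler_ddp=False)

# Proposed direction: the knob lives on the strategy that actually needs it.
# ("replace_sampler" is a hypothetical constructor argument, not an existing API.)
trainer = Trainer(strategy=DDPStrategy(replace_sampler=False))
```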

Additional context



cc @justusschock @awaelchli @akihironitta @rohitgr7 @ninginthecloud @tchaton @borda

awaelchli commented 2 years ago

This is a great concept and I feel we should definitely explore it. However, I see some challenges in putting everything into this single method, because the logic is spread out and interleaved with other, non-strategy-related functions. To get to a proof of concept, one could attempt an intermediate refactor that first pulls the strategy-related code together and merges it into a single method inside the connector, to find out where the blockers are.

It would also be great if we could make it free of trainer references so that the hook can be used in Lite as well. In this sense, I would first explore the generic def process_dataloader(self, dataloader, use: str) -> Union[DataLoader, Iterable] rather than the prefixed versions.
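
A minimal sketch of why that helps: if the hook only needs the dataloader and the stage name, both the Trainer's data connector and Lite can call it in exactly the same way (prepare_dataloader below is an illustrative helper, not an existing function):

```python
from typing import Iterable, Union

from torch.utils.data import DataLoader


def prepare_dataloader(strategy, dataloader: DataLoader, use: str) -> Union[DataLoader, Iterable]:
    # Works identically whether it is called from the Trainer's data connector
    # or from Lite's dataloader setup, since no trainer state is required.
    return strategy.process_dataloader(dataloader, use=use)
```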