Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Add something like `use_compile` parameter for Trainer #20242

Open mieshkiwrk opened 2 months ago

mieshkiwrk commented 2 months ago

Description & Motivation

In the example below, the model is compiled first and a DDPStrategy is passed to the Trainer. The strategy is only applied later, inside the fit method, so the model's forward is compiled but _pre_forward/_post_forward of the DDP class are not. Because of that, _pre_forward/_post_forward never disable the cpp_reducer, which later causes a problem with queueing the final backward callback. When the DDP wrapper itself is compiled, the cpp_reducer is disabled as expected.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
import functools
from torch._dynamo import compiled_autograd

torch._dynamo.config.optimize_ddp = "python_reducer"

class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.mean((self(x) - y) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

def create_dataset(num_samples=1000):
    x = torch.randn(num_samples, 10)
    y = torch.sum(x, dim=1, keepdim=True)
    return TensorDataset(x, y)

def run_training():
    dataset = create_dataset()
    train_loader = DataLoader(dataset, batch_size=32)

    model = SomeModel()

    # Compile the whole model first
    model = torch.compile(model)

    # static_graph must be True so that _DDPSink.backward queues the final callback
    ddp_strategy = DDPStrategy(static_graph=True)

    trainer = pl.Trainer(
            max_epochs=2,
            accelerator='cpu',
            devices=1,
            strategy=ddp_strategy
    )

    # DDP is applied inside fit(), so DDP's _pre_forward/_post_forward are not compiled even though forward is
    with compiled_autograd.enable(torch.compile()):
        trainer.fit(model, train_loader)

if __name__ == "__main__":
    run_training()

Resulting error:

[rank0]:     Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
[rank0]: RuntimeError: Final callbacks can only be installed during backward pass. 

Pitch

It seems useful to compile after the strategy has been applied, so my suggestion is to add something like a boolean `use_compile` parameter to the Trainer. It would help in situations like this one and would also be cleaner to use. It should probably be more advanced than a plain bool, so that a specific backend and other optional compile parameters can be configured.
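
A rough usage sketch of the proposal (use_compile and compile_kwargs are hypothetical parameters here, not part of the current Trainer API):

# Hypothetical API sketch: use_compile / compile_kwargs do not exist today,
# they only illustrate what the proposal could look like for users.
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",
    devices=1,
    strategy=DDPStrategy(static_graph=True),
    use_compile=True,                         # compile after the strategy is applied
    compile_kwargs={"backend": "inductor"},   # optional torch.compile settings
)
trainer.fit(SomeModel(), train_loader)        # no explicit torch.compile call in user code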

Alternatives

src/lightning/pytorch/trainer/trainer.py

class Trainer:
    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        (...)
        # ----------------------------
        # SET UP THE TRAINER
        # ----------------------------
        (...)
        self.strategy.setup(self)
        (...)

------>         ### Pseudo proposition
------>         if self.use_compile:
------>             self.model = torch.compile(self.model)

        # ----------------------------
        # RUN THE TRAINER
        # ----------------------------
        results = self._run_stage()

        (...)
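
Until something like this exists, one possible workaround is sketched below. It assumes that DDPStrategy.setup is where the module gets wrapped in DistributedDataParallel and that the wrapped module can be read and reassigned through strategy.model; treat it as an illustration rather than a supported API.

# Workaround sketch (assumption-based, not an official API): compile the
# DDP-wrapped module right after the built-in setup has wrapped it, so that
# DDP's _pre_forward/_post_forward are compiled together with forward.
class CompiledDDPStrategy(DDPStrategy):
    def setup(self, trainer):
        super().setup(trainer)                  # wraps the module in DistributedDataParallel
        self.model = torch.compile(self.model)  # assumes strategy.model is settable

trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",
    devices=1,
    strategy=CompiledDDPStrategy(static_graph=True),
)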

Additional context

File: torch/nn/parallel/distributed.py

class DistributedDataParallel(Module, Joinable): 
    def _should_disable_cpp_reducer(self) -> bool: 
        return self._use_python_reducer and ( 
            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer 
        )

    def _pre_forward(self, *inputs, **kwargs):        
        if self._should_disable_cpp_reducer():                  
            return inputs, kwargs
        (...)

    def _post_forward(self, output): 
        if self._should_disable_cpp_reducer(): 
            return output
        (...)

Dynamo replaces the output of torch._utils.is_compiling() with True when the code is compiled; otherwise the call returns False.
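
A minimal, standalone illustration of that replacement:

import torch

def f(x):
    # Dynamo swaps this call for the constant True while tracing,
    # so the compiled function takes the first branch.
    if torch._utils.is_compiling():
        return x + 1
    return x - 1

x = torch.zeros(1)
print(f(x))                  # eager: tensor([-1.])
print(torch.compile(f)(x))   # compiled: tensor([1.])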

cc @borda

mieshkiwrk commented 2 months ago

cc @jerome-habana