Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

torch.compile support from lightning CLI #17283

Open samvanstroud opened 1 year ago

samvanstroud commented 1 year ago

Description & Motivation

As discussed in https://github.com/Lightning-AI/lightning/issues/15894, it would be nice to ask the trainer to compile the LightningModule with a compile=True flag or similar.

Since that issue was closed, I'm opening this one to request the feature again.

Pitch

No response

Alternatives

No response

Additional context

No response

cc @borda @carmocca @mauvilsa

nanoric commented 1 year ago

As a beginner running small models, I find that torch.compile runs about 1.5x faster than the non-compiled version, so I think torch.compile would be pretty useful for us.

carmocca commented 1 year ago

There are two ways to do this right now:

import torch
from lightning.pytorch.cli import LightningCLI

# Instantiate the classes without running a subcommand, then compile and fit manually.
cli = LightningCLI(run=False)
compiled_model = torch.compile(cli.model)
cli.trainer.fit(compiled_model, datamodule=cli.datamodule)

with python script.py

OR

import torch
from lightning.pytorch.cli import LightningCLI

class TorchCompileCLI(LightningCLI):
    def fit(self, model, **kwargs):
        # LightningCLI dispatches a subcommand to a method of the same name on
        # the CLI subclass if one exists, so this replaces the default trainer.fit call.
        compiled_model = torch.compile(model)
        self.trainer.fit(compiled_model, **kwargs)

TorchCompileCLI()

with python script.py fit

maxfreu commented 1 year ago

Currently this segfaults for me right before sanity checking the dataloaders, but the standalone model is compilable o.O

tkella47 commented 9 months ago

> Currently this segfaults for me right before sanity checking the dataloaders, but the standalone model is compilable o.O

This might be related to https://github.com/pytorch/pytorch/issues/107960#issuecomment-1709589190

mukhery commented 8 months ago

It seems like another way would be to use something like before_fit and then do self.model = torch.compile(self.model) in a custom CLI, as in the sketch below. But then one would need to implement all of the before_* functions to handle every subcommand.
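A minimal untested sketch of that idea (the class name is just for illustration):

import torch
from lightning.pytorch.cli import LightningCLI

class CompileOnFitCLI(LightningCLI):
    def before_fit(self):
        # LightningCLI calls a before_{subcommand} hook (if defined) right before
        # dispatching the subcommand, so the trainer receives the compiled model.
        # validate/test/predict would each need their own before_* hook.
        self.model = torch.compile(self.model)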

@carmocca is there a way to do this that would work regardless of the subcommand being used? For example, maybe there's something between instantiate_classes() and _run_subcommand that is intended to be overridden to support running code after self.model has been instantiated, but before the subcommand runs?

mauvilsa commented 8 months ago

> @carmocca is there a way to do this that would work regardless of the subcommand being used? For example, maybe there's something between instantiate_classes() and _run_subcommand that is intended to be overridden to support running code after self.model has been instantiated, but before the subcommand runs?

@mukhery you could implement a custom instantiator for LightningModule classes that compiles right before returning the class instance. See add_instantiator.
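Something along these lines (untested sketch, assuming jsonargparse's parser.add_instantiator API; the names are illustrative):

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI

def compile_instantiator(class_type, *args, **kwargs):
    # Instantiate the module as usual, then compile it before it is handed
    # back to the CLI, so every subcommand sees the compiled instance.
    return torch.compile(class_type(*args, **kwargs))

class CompileCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_instantiator(compile_instantiator, LightningModule)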

apple2373 commented 4 days ago

I don't think this is the best way, but this is what I ended up with.

from typing import Literal, Optional

import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from torchmetrics import Accuracy

class MyLightningModule(LightningModule):
    def __init__(
        self,
        model,
        num_class: int,
        input_key: str = "rgb",
        label_key: str = "label",
        label_smoothing: float = 0.1,
        compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] = None,
    ):
        super().__init__()
        self.label_smoothing = label_smoothing
        self.compile_mode = compile_mode
        self.num_class = num_class
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.accuracy = Accuracy(task="multiclass", num_classes=num_class)
        self.model = model
        self.input_key = input_key
        self.label_key = label_key
        if self.compile_mode is not None:
            print("compile_mode", self.compile_mode)
            # torch.compile shares parameters with self.model, so this does not
            # duplicate the weights in memory.
            self.model_compiled = torch.compile(model, mode=self.compile_mode)
            print("Model compiled.")

    def forward(self, x):
        # Route through the compiled module whenever compilation was requested.
        if self.compile_mode:
            return self.model_compiled(x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        labels = batch[self.label_key]
        inputs = batch[self.input_key]
        predictions = self(inputs)
        loss = self.criterion(predictions, labels)
        acc = self.accuracy(predictions, labels)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def eval_step(self, batch, batch_idx, prefix: str):
        labels = batch[self.label_key]
        inputs = batch[self.input_key]
        predictions = self(inputs)
        loss = self.criterion(predictions, labels)
        acc = self.accuracy(predictions, labels)
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        self.log(f"{prefix}_acc", acc, prog_bar=True)
        return {f"{prefix}_loss": loss, f"{prefix}_acc": acc}

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "test")

    def on_save_checkpoint(self, checkpoint):
        # Drop the compiled model's entries from the state dict, as compile
        # results may differ per device; the original model's weights are
        # still saved under the "model." keys.
        # c.f.: https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739/2
        for key in list(checkpoint["state_dict"].keys()):
            if key.startswith("model_compiled."):
                del checkpoint["state_dict"][key]

Then I can run python main.py fit --model.compile_mode default.
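Equivalently, assuming the standard LightningCLI config handling, the same option can go in a YAML config:

# config.yaml
model:
  compile_mode: default

and be used with python main.py fit --config config.yaml.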