Open samvanstroud opened 1 year ago
As a beginner running small models, I find that the torch.compile version runs about 1.5x faster than the non-compiled version, so I think torch.compile is pretty useful for us.
There are two ways to do this right now:
cli = LightningCLI(run=False)
compiled_model = torch.compile(cli.model)
cli.trainer.fit(compiled_model)
with python script.py
OR
class TorchCompileCLI(LightningCLI):
    def fit(self, model, **kwargs):
        compiled_model = torch.compile(model)
        self.trainer.fit(compiled_model, **kwargs)

TorchCompileCLI()
with python script.py fit
Currently this segfaults for me right before sanity checking the dataloaders, but the standalone model is compilable o.O
This might be related to https://github.com/pytorch/pytorch/issues/107960#issuecomment-1709589190
It seems like another way would be to use something like before_fit to then do self.model = torch.compile(self.model) in a custom CLI. But then one would need to implement all of the before_* functions to handle any subcommand.
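A rough sketch of the before_* idea, not an official Lightning pattern: instead of writing each hook by hand, one helper can be attached to every before_<subcommand> name. The mixin name, the compile_model method, and the list of subcommands are all assumptions made for illustration, and this only helps if the CLI reads self.model after the hook has run.

```python
# Hypothetical mixin: every name here is illustrative, not part of Lightning.
SUBCOMMANDS = ("fit", "validate", "test", "predict")  # assumed subcommand names


class CompileBeforeSubcommandMixin:
    """Intended use: class MyCLI(CompileBeforeSubcommandMixin, LightningCLI)."""

    def compile_model(self):
        # Imported lazily so the mixin can be defined without torch installed.
        import torch

        self.model = torch.compile(self.model)


def _make_hook():
    # Each hook simply delegates to compile_model on the CLI instance.
    def hook(self):
        self.compile_model()

    return hook


# Attach before_fit, before_validate, before_test and before_predict.
for _name in SUBCOMMANDS:
    setattr(CompileBeforeSubcommandMixin, f"before_{_name}", _make_hook())
```

Overriding compile_model in a subclass would be the place to pass a mode or to skip compilation for some runs.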
@carmocca is there a way to do this that would work regardless of the subcommand being used? For example, maybe there's something between instantiate_classes() and _run_subcommand that is intended to be overridden to support running code after self.model has been instantiated, but before the subcommand runs?
@mukhery you could implement a custom instantiator for LightningModule classes that compiles right before returning the class instance. See add_instantiator.
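A minimal sketch of what such an instantiator could look like, assuming it is called with the class type followed by the init arguments. The factory name make_compiling_instantiator is made up for this example, and the registration shown in the trailing comment is an untested assumption; whether it takes precedence over any instantiator Lightning registers itself may depend on the version.

```python
from typing import Any, Callable


def make_compiling_instantiator(compile_fn: Callable[[Any], Any]) -> Callable[..., Any]:
    """Build an instantiator that compiles each instance before returning it."""

    def instantiator(class_type, *args, **kwargs):
        # Create the instance as usual, then hand it to the compile function.
        instance = class_type(*args, **kwargs)
        return compile_fn(instance)

    return instantiator


# Hypothetical registration inside a LightningCLI subclass (not verified here):
#     def add_arguments_to_parser(self, parser):
#         parser.add_instantiator(
#             make_compiling_instantiator(torch.compile), LightningModule
#         )
```

Passing compile_fn in, rather than hard-coding torch.compile, keeps it easy to set a mode with functools.partial or to swap in a no-op while debugging.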
I don't think this is the best way, but this is what I ended up with.
# Imports assumed for this snippet; adjust to your Lightning install.
from typing import Literal, Optional

import torch
from torch import nn
from torchmetrics import Accuracy
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):
    def __init__(
        self,
        model,
        num_class: int,
        input_key: str = "rgb",
        label_key: str = "label",
        label_smoothing: float = 0.1,
        compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] = None,
    ):
        super(MyLightningModule, self).__init__()
        self.label_smoothing = label_smoothing
        self.compile_mode = compile_mode
        self.num_class = num_class
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.accuracy = Accuracy(task="multiclass", num_classes=num_class)
        self.model = model
        self.input_key = input_key
        self.label_key = label_key
        if self.compile_mode is not None:
            print("compile_mode", self.compile_mode)
            self.model_compiled = torch.compile(model, mode=self.compile_mode)
            print("Model compiled.")

    def forward(self, x):
        # Use the compiled wrapper when a compile mode was requested.
        if self.compile_mode:
            return self.model_compiled(x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        labels = batch[self.label_key]
        inputs = batch[self.input_key]
        predictions = self(inputs)
        loss = self.criterion(predictions, labels)
        acc = self.accuracy(predictions, labels)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def eval_step(self, batch, batch_idx, prefix: str):
        labels = batch[self.label_key]
        inputs = batch[self.input_key]
        predictions = self(inputs)
        loss = self.criterion(predictions, labels)
        acc = self.accuracy(predictions, labels)
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        self.log(f"{prefix}_acc", acc, prog_bar=True)
        return {f"{prefix}_loss": loss, f"{prefix}_acc": acc}

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "test")

    def on_save_checkpoint(self, checkpoint):
        # do not want to save the compiled model state dict, as compile results may differ per device
        # it will still save the original model state dict
        # c.f.: https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739/2
        for key in list(checkpoint["state_dict"].keys()):
            if key.startswith("model_compiled."):
                del checkpoint["state_dict"][key]
Then I can run python main.py fit --model.compile_mode default
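To make the on_save_checkpoint filtering above concrete, here is the same pruning applied to a plain dict. The helper name drop_compiled_keys and the example keys are made up for illustration; the "_orig_mod." segment reflects how torch.compile wrappers usually prefix the wrapped module's parameters, and real checkpoints hold tensors rather than strings.

```python
def drop_compiled_keys(state_dict: dict, prefix: str = "model_compiled.") -> dict:
    """Remove, in place, every entry whose key starts with `prefix`."""
    # Iterate over a copy of the keys so deleting while looping is safe.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            del state_dict[key]
    return state_dict


# Illustrative keys only; the values stand in for tensors.
example = {
    "model.fc.weight": "w",
    "model.fc.bias": "b",
    "model_compiled._orig_mod.fc.weight": "w",
    "model_compiled._orig_mod.fc.bias": "b",
}
drop_compiled_keys(example)
print(sorted(example))
```

Only the entries under the original model attribute remain, so the checkpoint carries a single copy of the weights.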
Description & Motivation
As discussed in https://github.com/Lightning-AI/lightning/issues/15894, it would be nice to ask the trainer to compile the LightningModule with a compile=True flag or similar. Since that issue was closed, I'm opening this one to request the feature again.
cc @borda @carmocca @mauvilsa