Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Lightning 1.5 auto adds "fit" key to top level of CLI config breaking jsonargparse configuration check #10460

Closed. rbracco closed this issue 2 years ago.

rbracco commented 2 years ago

πŸ› Bug

Since CLI subcommands were added in 1.5, my config files break after a run because a new top-level key "fit" is added. I run the command python train_scripts/trainer.py fit --config config.yaml as recommended in the docs and it runs properly, but config.yaml is overwritten and goes from the format

trainer:
   gpus: 1
data:
   num_workers: 8

to having it all nested under the 'fit' key as follows:

fit:
   trainer:
       gpus: 1
   data:
       num_workers: 8

Then when I run python train_scripts/trainer.py fit --config config.yaml again, it fails with the jsonargparse error message trainer.py: error: Configuration check failed :: Key "data.val_manifest" is required but not included in config object or its value is None., presumably because that value is now found under "fit.data.val_manifest".
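
A possible workaround is to regenerate the config in the new layout rather than reusing the overwritten file, for example

python train_scripts/trainer.py fit --print_config > config.yaml

but the original file should not be rewritten in the first place.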

To Reproduce

Expected behavior

Running python train_scripts/trainer.py fit --config config.yaml should not clobber the config file. The documented behavior is that config files passed after the subcommand do not need a top-level subcommand key:

(Screenshot of the LightningCLI docs describing config files passed after the subcommand.)

Environment

carmocca commented 2 years ago

Mind sharing your LightningCLI implementation and how you instantiate it?

rbracco commented 2 years ago

Sure, I've included it below. Thanks for taking a look.

from pytorch_lightning.utilities.cli import LightningCLI

# FinetuneEncoderDecoder, loss_functions, blocks, QuartznetCheckpoint,
# QNPronounceModule, TextTransformConfig and PronounceDatamodule are
# project-specific imports, omitted here.

class FinetuneCLI(LightningCLI):
    def add_arguments_to_parser(self, parser) -> None:
        parser.add_argument("--loss_function")
        parser.add_argument("--decoder")
        parser.add_lightning_class_args(FinetuneEncoderDecoder, "finetuner")
        parser.set_defaults(
            {
                "trainer.max_epochs": 10,
                "trainer.gpus": -1,
                "trainer.log_every_n_steps": 10,
                "trainer.flush_logs_every_n_steps": 50,
            }
        )
        return super().add_arguments_to_parser(parser)

    def before_fit(self):
        self.model.loss_function = getattr(loss_functions, self.config["loss_function"])
        DecoderClass = getattr(blocks, self.config["decoder"])
        num_classes = len(self.model.text_transform.vocab)
        model_arch = str(type(self.model))
        self.model.decoder = DecoderClass(1024, num_classes)

def get_quartznet(
    checkpoint: str = "QuartzNet15x5Base_En",
    learning_rate: float = 3e-4,
    labels: list = None,
) -> QuartznetModule:
    checkpoint = QuartznetCheckpoint.from_string(checkpoint)
    module = QNPronounceModule.load_from_nemo(checkpoint)
    module.hparams.optim_cfg.learning_rate = learning_rate
    text_config = TextTransformConfig(labels)
    module.change_vocab(text_config)
    return module

FinetuneCLI(
    model_class=get_quartznet,
    datamodule_class=PronounceDatamodule,
    save_config_overwrite=True,
)
carmocca commented 2 years ago

So you are saying that running

python script.py fit --print_config > config.yaml

saves the correct config, which can be used later with

python script.py fit --config config.yaml

but at some point config.yaml gets overwritten with the config format that includes the subcommand.

If that's the case, I haven't been able to reproduce it with this script:

import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

class FinetuneCLI(LightningCLI):
    def add_arguments_to_parser(self, parser) -> None:
        parser.set_defaults({"trainer.max_epochs": 10})

    def before_fit(self):
        ...

FinetuneCLI(
    model_class=BoringModel,
    save_config_overwrite=True,
)
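
For completeness, the two commands from above, applied to this script:

python script.py fit --print_config > config.yaml
python script.py fit --config config.yaml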
rbracco commented 2 years ago

So you are saying that running

python script.py fit --print_config > config.yaml

saves the correct config, which can be used later with

python script.py fit --config config.yaml

but at some point config.yaml gets overwritten with the config format that includes the subcommand.

Looking at this, I think the issue may be that I was using a config file maintained from a previous version of Lightning, and the rules changed in 1.5. I have already rolled back and have everything working, but eventually I will upgrade again and will reopen if the issue persists, or update to confirm that the problem was the initial config formatting. Thanks for looking into this.
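
For anyone who hits the same error after upgrading: going by the docs screenshot above, a config passed after the fit subcommand keeps its flat top-level keys, so an old-style file would stay something like

trainer:
   gpus: 1
data:
   num_workers: 8
   val_manifest: <path to manifest>  # placeholder

rather than being nested under a top-level fit key.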