Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

With a YAML config file for LightningCLI, `self.save_hyperparameters()` behaves abnormally #19977

Open t4rf9 opened 2 weeks ago

t4rf9 commented 2 weeks ago

Bug description

When a YAML config file is used with LightningCLI, `self.save_hyperparameters()` called in the `__init__` of the model and datamodule mistakenly saves a dict containing keys like `class_path` and `init_args` instead of the plain init arguments.

This problem appears in version 2.3.0; version 2.2.5 works correctly.
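
To summarize the symptom, here is a sketch of the expected versus observed contents of self.hparams, based on the repro below (exact repr of the printed values may differ; the _instantiator entry is shown as a string here):

# Expected, and what 2.2.5 saves: the plain init kwargs.
expected = {"learning_rate": 0.01}

# Observed under 2.3.0 when the model is instantiated through LightningCLI:
observed = {
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "class_path": "model.Model",
    "init_args": {"learning_rate": 0.01},
}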

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

config.yaml

ckpt_path: null
seed_everything: 0
model:
  class_path: model.Model
  init_args:
    learning_rate: 1e-2
data:
  class_path: datamodule.DataModule
  init_args:
    data_dir: data
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: null
  fast_dev_run: false
  max_epochs: 100
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: 10
  limit_test_batches: null
  limit_predict_batches: null
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: lightning_logs
      name: normalized
  callbacks:
    class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      save_top_k: 5
      monitor: valid_loss
      filename: "{epoch}-{step}-{valid_loss:.8f}"
  overfit_batches: 0.0
  val_check_interval: 50
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 50
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: true
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null

model.py

import lightning as pl

class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()

        print()
        print("Model:")

        print(f"learning_rate: {learning_rate}")
        ## This outputs correctly.

        self.save_hyperparameters()

        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    model.Model
        # "init_args":     {'learning_rate': 0.01}

datamodule.py

from lightning import LightningDataModule

class DataModule(LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        self.save_hyperparameters()

        print()
        print("DataModule:")

        print(self.hparams)
        ## This outputs
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    datamodule.DataModule
        # "init_args":     {'data_dir': 'data'}

main.py

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from model import Model
from datamodule import DataModule

def cli_main():
    cli = LightningCLI()

if __name__ == "__main__":
    cli_main()

Run `python main.py fit --config config.yaml`
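
Until this is fixed, a hypothetical stop-gap (untested; it assumes the class_path/init_args layout shown above) is to unwrap init_args after calling self.save_hyperparameters(); pinning lightning==2.2.5 also avoids the problem, since that version behaves correctly. A minimal sketch of such a helper, with an illustrative name:

from typing import Any, Mapping

def unwrap_init_args(hparams: Mapping[str, Any]) -> dict:
    """Hypothetical helper: return the plain init kwargs whether or not
    Lightning stored the class_path/init_args wrapper."""
    hp = dict(hparams)
    # If the wrapper keys are present, the real kwargs live under "init_args";
    # otherwise hparams already holds the kwargs directly.
    return dict(hp.get("init_args", hp))

# Usage inside __init__, right after self.save_hyperparameters():
#   learning_rate = unwrap_init_args(self.hparams)["learning_rate"]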


adamjstewart commented 2 weeks ago

We're seeing this too; it broke all of TorchGeo's tests: https://github.com/microsoft/torchgeo/actions/runs/9522133755/job/26251028463?pr=2119

EthanMarx commented 2 weeks ago

+1

CarlosGomes98 commented 1 week ago

We are seeing this as well: https://github.com/IBM/terratorch/issues/26

As far as I can tell, it stems from https://github.com/Lightning-AI/pytorch-lightning/pull/19771, which (inadvertently?) affects the LightningCLI parser.

adamjstewart commented 6 days ago

Still broken in 2.3.1, still preventing TorchGeo from supporting newer versions of Lightning.

adamjstewart commented 1 day ago

Still broken in 2.3.2.