Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LightningCLI: config argument linkings in config file #8924

Closed: quancs closed this issue 3 years ago

quancs commented 3 years ago

🚀 Feature

Show and configure argument linking in the config file

Motivation

In the current design of the CLI, if our models have linked arguments we need to write a dedicated MyLightningCLI subclass for each model. Providing this feature would save that effort, since we could keep using the default LightningCLI implementation.
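For reference, a minimal sketch of what this currently requires: one LightningCLI subclass per linking setup, using the existing add_arguments_to_parser / link_arguments hooks (the class and argument names below are only illustrative):

from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI

class MyLightningCLI(LightningCLI):

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # pass the datamodule's num_classes to the model when it is instantiated;
        # once linked, model.num_classes is no longer settable from the config file
        parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")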

Pitch

data:
  num_classes: 2
model:
  num_classes:
    link_source:  data.num_classes
    link_compute_fn:  path.to.func
    link_apply_on:  instantiate

Alternatives

CLI:
  links:
    - source: data.num_classes
      target:  model.num_classes
      compute_fn:  path.to.func
      apply_on:  instantiate
data:
  num_classes: 2
model:
  num_classes: null

or for simple value linking:

data:
  num_classes: 2
model:
  num_classes: <-- data.num_classes

Additional context



tchaton commented 3 years ago

@mauvilsa Any thoughts for jsonargparse?

tchaton commented 3 years ago

Dear @quancs,

I believe this feature request should be opened on jsonargparse directly :)

Best, T.C

mauvilsa commented 3 years ago

I will explain a bit how the argument linking feature came to be and why I don't like this proposal. The objective of config files should only be to configure things; otherwise there is a blurred line between what is source code and what is configuration. Argument linking is intended for cases in which some value is always derived from configurable values, i.e. the linked argument itself is not configurable. This is why, when an argument is linked, it is no longer specifiable in the config file. In other words, linked arguments are not configuration and they belong in the source code. Not wanting to subclass LightningCLI is not a good reason to add complexity to the config files.

@quancs could you please explain the use case which led you to request this feature? There might be something that could belong in the config file, or we could come up with a way to achieve what you need without adding complexity to the config.

quancs commented 3 years ago

In the development of my model I try plenty of ideas, which results in plenty of model implementations, and for each implementation I need to implement its own CLI. Or is there some way I can share their CLI implementations? Or is the best practice not to share CLIs?

Another reason for this feature is that I could clearly see in the config file where the shared parameters come from. In the current design we cannot even tell from the config file that a parameter is needed to construct the linking target. I guess that adds a little more complexity when we share our model with others, especially with people who are not familiar with Lightning. See the example from speechbrain (they also provide a YAML based config file). From that example we can clearly see all the parameters needed to construct the objects and which parameter is shared from which place. That is clearer for me. @mauvilsa what do you think?

quancs commented 3 years ago

HyperPyYAML is the YAML package used in speechbrain. Its functionality is very novel and useful, and the grammar is simple too.
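For illustration, a hedged sketch of the HyperPyYAML reference syntax (assumes pip install hyperpyyaml; the !ref tag follows its documentation, and the keys and values here are made up):

from hyperpyyaml import load_hyperpyyaml

yaml_text = """
num_classes: 2
model:
    num_classes: !ref <num_classes>  # shared from the top-level key above
"""

# load_hyperpyyaml resolves the !ref tag, so both keys end up with the same value
print(load_hyperpyyaml(yaml_text))
# expected: {'num_classes': 2, 'model': {'num_classes': 2}}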

mauvilsa commented 3 years ago

YAML already provides features like anchors that allow reusing parts of a config. There are also many packages that implement interpolation on top of YAML, e.g. Ansible, and I think it would be great if people could optionally switch jsonargparse to a YAML library that supports it. However, there has been a tendency for each package to implement interpolation internally without releasing a reusable YAML library. Great that HyperPyYAML exists standalone; I will look at it to see if it could optionally replace pyyaml in jsonargparse.
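As a concrete illustration of the anchor approach, plain PyYAML already resolves anchors and aliases (the key names here are just an example):

import yaml

config_text = """
data:
  num_classes: &nc 2   # define an anchor on the value
model:
  num_classes: *nc     # reuse the anchored value
"""

print(yaml.safe_load(config_text))
# -> {'data': {'num_classes': 2}, 'model': {'num_classes': 2}}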

However, note that my opinion remains the same. The objective of argument linking is non-configurable things. Interpolation is okay for some use cases, but it should not be overused for things that should not be configurable.

@quancs I did get the idea that you want to try out many models, and I think it is great if you have a single CLI for it. But I still don't think it is a good idea to introduce into the config file the details of how to adapt different model implementations to a common interface. Interface adaptation should be in the source code. I was hoping you would give more specific details about your use case, to see if there are needs I have not thought of.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

mauvilsa commented 2 years ago

> HyperPyYAML is the YAML package used in speechbrain. Its functionality is very novel and useful, and the grammar is simple too.

@quancs with jsonargparse v4.1.0, released today, it is now possible to change the YAML loader, which makes it possible to add support for variable interpolation. Out of the box you can use OmegaConf, see https://jsonargparse.readthedocs.io/en/stable/#variable-interpolation, but if you want to use HyperPyYAML simply register a loader function that uses it: https://jsonargparse.readthedocs.io/en/latest/#custom-loaders
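A hedged sketch of how this looks at the jsonargparse level (assumes jsonargparse>=4.1.0 and omegaconf are installed; the argument names are only illustrative). In LightningCLI the same loader is selected with parser_kwargs={"parser_mode": "omegaconf"}, as in the reproduction further down:

from jsonargparse import ArgumentParser

parser = ArgumentParser(parser_mode="omegaconf")
parser.add_argument("--data.num_classes", type=int)
parser.add_argument("--model.num_classes", type=int)

# the OmegaConf loader resolves ${...} interpolations when the config is loaded
cfg = parser.parse_string(
    "data:\n"
    "  num_classes: 2\n"
    "model:\n"
    "  num_classes: ${data.num_classes}\n"
)
print(cfg.model.num_classes)  # expected: 2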

quancs commented 2 years ago

> @quancs with jsonargparse v4.1.0, released today, it is now possible to change the YAML loader, which makes it possible to add support for variable interpolation. Out of the box you can use OmegaConf, see https://jsonargparse.readthedocs.io/en/stable/#variable-interpolation, but if you want to use HyperPyYAML simply register a loader function that uses it: https://jsonargparse.readthedocs.io/en/stable/#custom-loaders

Great! I will try, thank you for your nice work.

quancs commented 2 years ago

@mauvilsa variable interpolation does not work. Run python boring.py fit --config boring.yaml:

boring.py

import os
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

class Arch:

    def __init__(self, input_size: int, channels: List[int], a: int = 10) -> None:
        pass

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):

    def __init__(self, arch: Arch, channels: List[int] = [0, 1, 2, 3, 4, 5, 6, 7, 8]):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

from pytorch_lightning.utilities.cli import (LightningArgumentParser, LightningCLI)

class MyCLI(LightningCLI):

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        default_arch = {
            "class_path": "__main__.Arch",
            "init_args": {
                "a": "10",
                # "channels": "${model.channels}"
            },
        }

        parser.set_defaults({
            "model.arch": default_arch,
        })

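        # derive Arch's input_size from the model's channels at parse time (input_size = 2 * len(channels))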
        parser.link_arguments("model.channels", "model.arch.init_args.input_size", compute_fn=lambda channels: 2 * len(channels), apply_on="parse")

        return super().add_arguments_to_parser(parser)

from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):

    def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
        super().__init__(train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, dims=dims)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)

if __name__ == '__main__':
    cli = MyCLI(BoringModel, MyDataModule, seed_everything_default=None, save_config_overwrite=True, parser_kwargs={"parser_mode": "omegaconf"})

boring.yaml

seed_everything: null
trainer:
  logger: true
  checkpoint_callback: null
  enable_checkpointing: true
  callbacks: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  process_position: 0
  num_nodes: 1
  num_processes: 1
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  log_gpu_memory: null
  progress_bar_refresh_rate: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  limit_predict_batches: 1.0
  val_check_interval: 1.0
  flush_logs_every_n_steps: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_summary: top
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: false
  deterministic: false
  reload_dataloaders_every_n_epochs: 0
  reload_dataloaders_every_epoch: false
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  prepare_data_per_node: null
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  stochastic_weight_avg: false
  terminate_on_nan: null
model:
  arch:
    class_path: __main__.Arch
    init_args:
      channels: "${model.channels}"
      a: 10
  channels:
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
data:
  train_transforms: null
  val_transforms: null
  test_transforms: null
  dims: null
ckpt_path: null

quancs commented 2 years ago

Oh, I see. It's not in the most recent release yet...