lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Use structured configs to enforce types in command-line tool #586

Closed philippmwirth closed 1 year ago

philippmwirth commented 2 years ago

Use structured configs to enforce types in command-line tool

Using structured configs is a better solution for #575. However, it requires dataclasses, which are only available from Python 3.7 onwards. There is a backport of dataclasses to Python 3.6, so we could try that, but we would have to be careful and test thoroughly.
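
A minimal sketch of how the backport could be declared as a conditional dependency (the setup.py contents here are illustrative, not our actual packaging):

# setup.py (illustrative only): pull in the PyPI "dataclasses" backport on
# Python 3.6. The backport installs under the same module name, so the
# `from dataclasses import dataclass` imports below work unchanged.
from setuptools import setup

setup(
    name="lightly",
    install_requires=[
        'dataclasses; python_version < "3.7"',  # backport for Python 3.6
    ],
)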

Another option would be to "hack" something together with different config files for Python 3.6 and Python >= 3.7.
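
Roughly, that hack could look like this (config_typed is a hypothetical typed variant of the config; nothing like this exists in the repo):

# Hypothetical sketch: pick a typed config on Python >= 3.7 and fall back to
# the plain config.yaml on Python 3.6. `config_typed` does not exist; it
# stands in for a config variant whose defaults list pulls in the schema.
import sys

import hydra

CONFIG_NAME = "config_typed" if sys.version_info >= (3, 7) else "config"

@hydra.main(config_path="config", config_name=CONFIG_NAME)
def download_cli(cfg):
    ...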

Docs for structured configs with hydra: https://hydra.cc/docs/tutorials/structured_config/schema/

For my experiments, I used the following code, which enforces types for the download_cli:

lightly/cli/config/config.yaml:


#defaults:
# - base_config

### i/o
# The following arguments specify input and output locations
# of images, embeddings, and checkpoints.
input_dir: ''                 # Path to input directory which holds images.
output_dir: ''                # Path to directory which should store downloads.
embeddings: ''                # Path to csv file which holds embeddings.
checkpoint: ''                # Path to a model checkpoint. If left empty, a pre-trained model
                              # will be used.
label_dir: ''                 # Path to the input directory which holds the labels.
label_names_file: ''          # Path to a yaml file with the label names under the key 'names'.
custom_metadata: ''           # Path to a json file in COCO format containing additional metadata

### Lightly platform
# The following arguments are required for requests to the
# Lightly platform.
token: ''                     # User access token to the Lightly platform.
dataset_id: ''                # Identifier of the dataset on the Lightly platform.
new_dataset_name: ''          # Name of the new dataset to be created on the Lightly platform
upload: 'full'                # Whether to upload full images, thumbnails only, or metadata only.
                              # Must be one of ['full', 'thumbnails', 'none']
resize: -1                    # Resize images before uploading (-1 = no resizing, x = resize to x pixels).
embedding_name: 'default'     # Name of the embedding to be used on the Lightly platform.
emb_upload_bsz: 32            # Number of embeddings which are uploaded in a single batch.
tag_name: 'initial-tag'       # Name of the requested tag on the Lightly platform.
exclude_parent_tag: False     # If true, only the samples in the defined tag, but without the parent tag, are taken.

### training and embeddings
pre_trained: True             # Whether to use a pre-trained model or not
crop_padding: 0.1             # The padding to use when cropping

# model namespace: Passed to lightly.models.ResNetGenerator.
model:
  name: 'resnet-18'           # Name of the model, currently supports popular variants:
                              # resnet-18, resnet-34, resnet-50, resnet-101, resnet-152.
  out_dim: 128                # Dimensionality of output on which self-supervised loss is calculated.
  num_ftrs: 32                # Dimensionality of feature vectors (embedding size).
  width: 1                    # Width of the resnet.

# criterion namespace: Passed to lightly.loss.NTXentLoss.
criterion:            
  temperature: 0.5            # Number by which logits are divided.
  memory_bank_size: 0         # Size of the memory bank to use (e.g. for MoCo). 0 means no memory bank.
                              # ^ slight abuse of notation, MoCo paper calls it momentum encoder

# optimizer namespace: Passed to torch.optim.SGD.
optimizer:
  lr: 1.                      # Learning rate of the optimizer.
  weight_decay: 0.00001       # L2 penalty.

# collate namespace: Passed to lightly.data.ImageCollateFunction.
collate:
  input_size: 64              # Size of the input images in pixels.
  cj_prob: 0.8                # Probability that color jitter is applied.
  cj_bright: 0.7              # Color_jitter intensity for brightness,
  cj_contrast: 0.7            # contrast,
  cj_sat: 0.7                 # saturation,
  cj_hue: 0.2                 # and hue.
  min_scale: 0.15             # Minimum size of random crop relative to input_size.
  random_gray_scale: 0.2      # Probability of converting image to gray scale.
  gaussian_blur: 0.5          # Probability of Gaussian blur.
  kernel_size: 0.1            # Kernel size of gaussian blur relative to input_size.
  vf_prob: 0.0                # Probability that vertical flip is applied.
  hf_prob: 0.5                # Probability that horizontal flip is applied.
  rr_prob: 0.0                # Probability that random (+-90 degree) rotation is applied.

# loader namespace: Passed to torch.utils.data.DataLoader.
loader:
  batch_size: 16              # Batch size for training / inference.
  shuffle: True               # Whether to reshuffle data each epoch.
  num_workers: -1             # Number of workers pre-fetching batches (-1 == number of available cores).
  drop_last: True             # Whether to drop the last batch during training.

# trainer namespace: Passed to pytorch_lightning.Trainer.
trainer:
  gpus: 1                     # Number of gpus to use for training.
  max_epochs: 100             # Number of epochs to train for.
  precision: 32               # If set to 16, will use half-precision.
  weights_summary: 'top'      # How to print the model architecture, one of {None, top, full}.
                              # See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#weights-summary

# checkpoint_callback namespace: Modify the checkpoint callback
checkpoint_callback:
  save_last: True             # Whether to save the checkpoint from the last epoch.
  save_top_k: 1               # Save the top k checkpoints.
  dirpath:                    # Where to store the checkpoints (empty field resolves to None).
                              # If not set, checkpoints are stored in the hydra output dir.

# seed
seed: 1

### hydra
# The arguments below are built-ins from the hydra-core Python package.
hydra:
  run:
    dir: lightly_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  help:
    header: |
      == Description ==
      The lightly Python package is a command-line tool for self-supervised learning.

    footer: |
      == Examples ==

      Use a pre-trained resnet-18 to embed your images
      > lightly-embed input_dir='path/to/image/folder' collate.input_size=224

      Load a model from a custom checkpoint to embed your images
      > lightly-embed input_dir='path/to/image/folder' collate.input_size=224 checkpoint='path/to/checkpoint.ckpt'

      Train a self-supervised model on your image dataset from scratch
      > lightly-train input_dir='path/to/image/folder' loader.batch_size=128 collate.input_size=224 pre_trained=False

      Train a self-supervised model starting from the pre-trained checkpoint
      > lightly-train input_dir='path/to/image/folder' loader.batch_size=128 collate.input_size=224

      Train a self-supervised model starting from a custom checkpoint
      > lightly-train input_dir='path/to/image/folder' loader.batch_size=128 collate.input_size=224 checkpoint='path/to/checkpoint.ckpt'

      Train using half-precision
      > lightly-train input_dir='path/to/image/folder' trainer.precision=16

      Upload thumbnails to the Lightly web solution
      > lightly-upload input_dir='path/to/image/folder' dataset_id='your_dataset_id' token='your_access_token'

      Upload only metadata of the images to the Lightly web solution
      > lightly-upload input_dir='path/to/image/folder' dataset_id='your_dataset_id' token='your_access_token' upload='metadata'

      Upload full images to the Lightly web solution
      > lightly-upload input_dir='path/to/image/folder' dataset_id='your_dataset_id' token='your_access_token' upload='full'

      Upload images and embeddings to the Lightly web solution
      > lightly-upload input_dir='path/to/image/folder' embeddings='path/to/embeddings.csv' dataset_id='your_dataset_id' token='your_access_token'

      Upload embeddings to the Lightly web solution
      > lightly-upload embeddings='path/to/embeddings.csv' dataset_id='your_dataset_id' token='your_access_token'

      Download a list of files in a given tag from the Lightly web solution
      > lightly-download tag_name='my-tag' dataset_id='your_dataset_id' token='your_access_token'

      Download a list of files in a given tag without filenames from the parent tag from the Lightly web solution
      > lightly-download tag_name='my-tag' dataset_id='your_dataset_id' token='your_access_token' exclude_parent_tag=True

      Copy all files in a given tag from a source directory to a target directory
      > lightly-download tag_name='my-tag' dataset_id='your_dataset_id' token='your_access_token' input_dir='data/' output_dir='new_data/'

      == Additional Information ==

      Use self-supervised methods to understand and filter raw image data:

      Website: https://www.lightly.ai
      Documentation: https://docs.lightly.ai

lightly/cli/config/config.py:

from typing import Optional
from hydra.core.config_store import ConfigStore

# dataclasses are only available from Python 3.7 onwards
from dataclasses import dataclass

@dataclass
class ModelConfig:
    name: str
    out_dim: int
    num_ftrs: int
    width: int

@dataclass
class CriterionConfig:
    temperature: float
    memory_bank_size: int

@dataclass
class OptimizerConfig:
    lr: float
    weight_decay: float

@dataclass
class CollateConfig:
    input_size: int
    cj_prob: float
    cj_bright: float
    cj_contrast: float
    cj_sat: float
    cj_hue: float
    min_scale: float
    random_gray_scale: float
    gaussian_blur: float
    kernel_size: float
    vf_prob: float
    hf_prob: float
    rr_prob: float

@dataclass
class LoaderConfig:
    batch_size: int
    shuffle: bool
    num_workers: int
    drop_last: bool

@dataclass
class TrainerConfig:
    gpus: int
    max_epochs: int
    precision: int
    weights_summary: Optional[str]

@dataclass
class CheckpointCallbackConfig:
    save_last: bool
    save_top_k: int
    dirpath: Optional[str]

@dataclass
class Config:

    model: ModelConfig
    criterion: CriterionConfig
    optimizer: OptimizerConfig
    collate: CollateConfig
    loader: LoaderConfig
    trainer: TrainerConfig
    checkpoint_callback: CheckpointCallbackConfig

    tag_name: str
    resize: int

    input_dir: str
    output_dir: str
    embeddings: str
    checkpoint: str

    label_dir: str
    label_names_file: str
    custom_metadata: str

    token: str
    dataset_id: str
    new_dataset_name: str
    upload: str

    embedding_name: str
    emb_upload_bsz: int

    exclude_parent_tag: bool

    pre_trained: bool
    crop_padding: float

    seed: int = 1

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
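
If I read the linked Hydra tutorial correctly, the stored base_config schema only takes effect once it appears in the defaults list of config.yaml, so presumably the commented-out defaults block at the top of the file has to be enabled for the validation to kick in.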

In lightly/cli/download_cli.py:

import hydra

from lightly.cli.config.config import Config

@hydra.main(config_path='config', config_name='config')
def download_cli(cfg: Config):
    """Download images from the Lightly platform.

    Args:
        cfg:
            The default configs are loaded from the config file.
            To overwrite them please see the section on the config file 
            (.config.config.yaml).

    Command-Line Args:
        tag_name:
            Download all images from the requested tag. Use initial-tag
            to get all images from the dataset.
        token:
            User access token to the Lightly platform. If dataset_id
            and token are specified, the images and embeddings are 
            uploaded to the platform.
        dataset_id:
            Identifier of the dataset on the Lightly platform. If 
            dataset_id and token are specified, the images and 
            embeddings are uploaded to the platform.
        input_dir:
            If input_dir and output_dir are specified, lightly will copy
            all images belonging to the tag from the input_dir to the 
            output_dir.
        output_dir:
            If input_dir and output_dir are specified, lightly will copy
            all images belonging to the tag from the input_dir to the 
            output_dir.

    Examples:
        >>> # download list of all files in the dataset from the Lightly platform
        >>> lightly-download token='123' dataset_id='XYZ'
        >>> 
        >>> # download list of all files in tag 'my-tag' from the Lightly platform
        >>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag'
        >>>
        >>> # download all images in tag 'my-tag' from the Lightly platform
        >>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' output_dir='my_data/'
        >>>
        >>> # copy all files in 'my-tag' to a new directory
        >>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' input_dir='data/' output_dir='my_data/'

    """
    # delegate to the existing download implementation in this module
    _download_cli(cfg)
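
To illustrate what the schema buys us, here is a self-contained sketch of the type enforcement using plain OmegaConf, outside of Hydra (LoaderConfig mirrors the definition above):

# Minimal sketch: OmegaConf rejects overrides that don't match the dataclass types.
from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.errors import ValidationError

@dataclass
class LoaderConfig:
    batch_size: int = 16

schema = OmegaConf.structured(LoaderConfig)

# A well-typed override merges cleanly and keeps batch_size an int.
print(OmegaConf.merge(schema, {"batch_size": 128}))

# A mistyped override is rejected instead of silently producing a str.
try:
    OmegaConf.merge(schema, {"batch_size": "sixteen"})
except ValidationError as err:
    print(err)  # roughly: Value 'sixteen' could not be converted to Integer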
MalteEbner commented 2 years ago

It might not be a big problem, but I found that we don't have a CLI keyword for uploading metadata. Instead, if upload is neither thumbnails nor full, the metadata is uploaded. This has led to us using very different keywords for the metadata upload: in some documentation it is called metadata, in other places meta, and in others None. Should we try to make this consistent and enforce a single valid keyword?
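
If we adopt the structured configs from above, one option (just a sketch, nothing is implemented) would be an Enum field, which OmegaConf validates on merge:

# Hypothetical sketch: restrict `upload` to exactly one spelling per mode.
# OmegaConf validates Enum fields on merge, so upload='meta' or upload='None'
# would fail instead of silently falling through to the metadata branch.
from dataclasses import dataclass
from enum import Enum

class UploadMode(Enum):
    full = "full"
    thumbnails = "thumbnails"
    metadata = "metadata"

@dataclass
class Config:
    upload: UploadMode = UploadMode.full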

guarin commented 1 year ago

Not planned.