🧢 Disent

A modular disentangled representation learning framework built with PyTorch Lightning


Visit the docs (https://disent.michlo.dev) for more info, or browse the releases. Disent is released under the MIT License.

Contributions are welcome!

────────────────
NOTE: My MSc. research has moved here.
Some of those contributions have been incorporated directly into disent.
────────────────




Overview

Disent is a modular disentangled representation learning framework for auto-encoders, built upon PyTorch Lightning. The framework consists of composable components that can be used to build and benchmark disentanglement pipelines for vision tasks.

The name of the framework is derived from both disentanglement and scientific dissent.

Get started with disent by installing it with `pip install disent`, or by cloning this repository.

Goals

Disent aims to meet the following criteria:

  1. Provide high-quality, readable, consistent and easily comparable implementations of frameworks
  2. Highlight the differences between framework implementations by overriding hooks and minimising duplicate code
  3. Use best practices, e.g. torch.distributions
  4. Be extremely flexible & configurable
  5. Support low-memory systems

Citing Disent

Please use the following citation if you use Disent in your own research:

@Misc{Michlo2021Disent,
  author =       {Nathan Juraj Michlo},
  title =        {Disent - A modular disentangled representation learning framework for pytorch},
  howpublished = {Github},
  year =         {2021},
  url =          {https://github.com/nmichlo/disent}
}

Features

Disent includes implementations of modules, metrics and datasets from various papers.

Note that "🧡" means that the dataset, framework or metric was introduced by disent!

Datasets

Various common datasets used in disentanglement research are included with disent. The dataset loaders provide various features, including:

+ data input and target augmentations and transforms
+ augmentations on the GPU or CPU at different points in the pipeline
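
A minimal sketch of loading data with an input transform, using the same classes as the full Python example further below:

```python3
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32

# ground-truth data, wrapped so that the transform is applied to
# each observation as it is loaded
data = XYObjectData()
dataset = DisentDataset(
    dataset=data,
    sampler=SingleSampler(),     # sample single observations (not pairs/triplets)
    transform=ToImgTensorF32(),  # numpy image -> float32 tensor, with checks
)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
```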

Frameworks

Disent provides various Auto-Encoder and Variational Auto-Encoder frameworks!
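
Frameworks follow a common construction pattern, as in the full Python example further below: each framework wraps an AutoEncoder model together with a framework-specific cfg. A minimal sketch, where the (64, 64, 3) observation shape stands in for data.x_shape from that example:

```python3
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64

# a framework wraps a model together with a framework-specific cfg
module = BetaVae(
    model=AutoEncoder(
        # z_multiplier=2 outputs both mu & logvar to parameterise the normal distribution
        encoder=EncoderConv64(x_shape=(64, 64, 3), z_size=10, z_multiplier=2),
        decoder=DecoderConv64(x_shape=(64, 64, 3), z_size=10),
    ),
    cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), beta=4),
)
```

Swapping to a different framework is then largely a matter of changing the class and its cfg.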

Introduced in Disent

πŸ— Todo: Many popular disentanglement frameworks still need to be added, please submit an issue if you have a request for an additional framework.

+ FactorVAE
+ GroupVAE
+ MLVAE

Metrics

Various metrics are provided by disent that can be used to evaluate the learnt representations of models trained on ground-truth data.
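
A minimal sketch of how a metric is invoked, mirroring the full example further below: each metric takes the ground-truth dataset plus a representation function mapping batches of observations to latents. The flattened-pixel `get_repr` here is just an illustrative stand-in for a trained model's `module.encode`, so the resulting scores are meaningless:

```python3
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.metrics import metric_mig

dataset = DisentDataset(
    dataset=XYObjectData(),
    sampler=SingleSampler(),
    transform=ToImgTensorF32(),
)

# stand-in for `module.encode` from a trained model: flattened pixels
get_repr = lambda x: x.flatten(start_dim=1)

scores = metric_mig(dataset, get_repr, num_train=2000)
print(scores)
```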

πŸ— Todo: Some popular metrics still need to be added, please submit an issue if you wish to add your own, or you have a request.

+ [DCIMIG](https://arxiv.org/abs/1910.05587)
+ [Modularity and Explicitness](https://arxiv.org/abs/1802.05312)

Schedules & Annealing

Hyper-parameter annealing is supported through the use of schedules, which are registered against values in a framework's cfg.
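
For example, as in the full example further below, a schedule is registered against a cfg value by name; the initial value from the cfg is saved and multiplied by the schedule's ratio at each training step (the model construction here is abbreviated from that example):

```python3
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.schedule import CyclicSchedule

module = BetaVae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(64, 64, 3), z_size=10, z_multiplier=2),
        decoder=DecoderConv64(x_shape=(64, 64, 3), z_size=10),
    ),
    cfg=BetaVae.cfg(beta=4),
)

# anneal 'beta': the saved cfg value (4) is multiplied by the schedule's
# ratio on each training step, repeating every `period` global steps
module.register_schedule('beta', CyclicSchedule(period=1024))
```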


Architecture

The disent module structure:
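
A partial view, inferred from the imports used in the examples in this README (descriptions are indicative, not exhaustive):

```
disent/
├── dataset/      # DisentDataset, with data, sampling & transform submodules
├── frameworks/   # AE & VAE framework implementations, e.g. BetaVae
├── metrics/      # disentanglement metrics, e.g. metric_dci, metric_mig
├── model/        # AutoEncoder and encoder/decoder architectures
└── schedule/     # hyper-parameter schedules, e.g. CyclicSchedule
```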

⚠️ The API Is Mostly Stable ⚠️

Disent is still under development. Features and APIs are subject to change! However, I will try to minimise the impact of these changes.

A small suite of tests currently exists, and will be expanded upon in time.

Hydra Experiment Directories

Easily run experiments with Hydra config. Note that these files are not available from `pip install`; clone the repository to use them.

Extending The Default Configs

All configs in experiment/config can easily be extended or overridden without modifying any files. We can add a new config folder to the Hydra search path by setting the environment variable DISENT_CONFIGS_PREPEND to point at a config folder whose contents should take priority over those in the default folder, for example `DISENT_CONFIGS_PREPEND=/path/to/my-configs`.

The advantage of this is that new frameworks and datasets can be used with experiments without cloning or modifying disent itself. You can separate your research code from the library!
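
For example (the folder and file names here are purely illustrative), an override folder mirrors the default experiment/config layout, and only needs to contain the files being added or replaced:

```
my-configs/
├── framework/
│   └── my_framework.yaml   # selectable as: framework=my_framework
└── dataset/
    └── my_dataset.yaml     # selectable as: dataset=my_dataset
```

With DISENT_CONFIGS_PREPEND pointing at this folder, the new options become available to experiment/run.py alongside the defaults.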


Examples

Python Example

The following is a basic working example of disent that trains a BetaVAE with a cyclic beta schedule and evaluates the trained model with various metrics.

💾 Basic Example

```python3
import lightning as L
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64
from disent.schedule import CyclicSchedule

# create the dataset & dataloaders
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks
# - if you use `num_workers != 0` in the DataLoader, then make sure to
#   wrap `trainer.fit` with `if __name__ == '__main__': ...`
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=0)

# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
    model=AutoEncoder(
        # z_multiplier is needed to output mu & logvar when parameterising normal distribution
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
    ),
    cfg=BetaVae.cfg(
        optimizer='adam',
        optimizer_kwargs=dict(lr=1e-3),
        loss_reduction='mean_sum',
        beta=4,
    )
)

# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
# config is saved and multiplied by the ratio from the schedule on each step.
# - based on: https://arxiv.org/abs/1903.10145
module.register_schedule(
    'beta', CyclicSchedule(
        period=1024,  # repeat every: trainer.global_step % period
    )
)

# train model
# - for 2048 batches/steps
trainer = L.Trainer(
    max_steps=2048,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(module, dataloader)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
    **metric_dci(dataset, get_repr, num_train=1000, num_test=500, show_progress=True),
    **metric_mig(dataset, get_repr, num_train=2000),
}

# evaluate
print('metrics:', metrics)
```

Visit the docs for more examples!

Hydra Config Example

The entrypoint for basic experiments is experiment/run.py.

Some configuration will be required, but basic experiments can be adjusted by modifying the Hydra Config 1.1 files in experiment/config.

Modifying the main experiment/config/config.yaml is all you need for most basic experiments. The main config file contains a defaults list, with entries corresponding to the YAML configuration files (config options) inside the subfolders (config groups), i.e. experiment/config/<config_group>/<option>.yaml.

💾 Config Defaults Example

```yaml
defaults:
  # data
  - sampling: default__bb
  - dataset: xyobject
  - augment: none
  # system
  - framework: adavae_os
  - model: vae_conv64
  # training
  - optimizer: adam
  - schedule: beta_cyclic
  - metrics: fast
  - run_length: short
  # logs
  - run_callbacks: vis
  - run_logging: wandb
  # runtime
  - run_location: local
  - run_launcher: local
  - run_action: train
  # ...
```

Easily modify any of these values to adjust how the basic experiment will be run. For example, change `framework: adavae_os` to `framework: betavae`, or change the dataset from `xyobject` to `shapes3d`. With Hydra's command-line overrides this looks something like `python experiment/run.py framework=betavae dataset=shapes3d`. Add new options by adding new YAML files in the config group folders.

Weights and Biases logging is enabled by setting `run_logging: wandb` (as in the defaults above); however, you will first need to log in from the command line. W&B logging supports visualisations of latent traversals.


Install

```bash
pip install disent
```

Otherwise, to install from source we recommend using a conda virtual environment.

Install from Source

```bash
# clone the repo
git clone https://github.com/nmichlo/disent
cd disent

# create and activate the conda environment [py38,py39,py310]
conda create -n disent-py310 python=3.10
conda activate disent-py310

# check that the correct python version is used
which python
which pip

# make sure to upgrade pip
pip install --upgrade pip

# install minimal requirements
pip install -r requirements.txt

# (optional) install extra requirements
# - first do the above, because torch is required to compile torchsort while installing
pip install -r requirements-extra.txt

# (optional) install test requirements
pip install -r requirements-test.txt
```

Development


Make sure to install pre-commit hooks to ensure code is automatically formatted correctly when committing or pushing changes to disent.

```bash
# install git hooks
pip install pre-commit
pre-commit install

# manually trigger all pre-commit hooks
pre-commit run --all-files
```

To run tests locally, make sure to install all the test and extra dependencies in your environment.

```bash
pip install -r requirements.txt
# torchsort first requires torch to be installed
pip install -r requirements-extra.txt -r requirements-test.txt
```

Why?