nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

[DOCS]: Add Experiment Examples #26

Open nmichlo opened 2 years ago

nmichlo commented 2 years ago

Is your feature request related to a problem? Please describe.
The current examples are very limited and only show how to use disent.

Describe the solution you'd like
Add examples for `experiment/run.py` and `experiment/config` that cover the new experiment plugin system introduced for milestone v0.4.0.

nmichlo commented 1 year ago

The current example lives at https://github.com/nmichlo/msc-research

To tell disent about your extra config files, set the following environment variable to the absolute path of the directory that contains them:

```bash
export DISENT_CONFIGS_PREPEND="/path/to/extra/config/dir"  # for the example above, this is: `<root>/research/config`
```
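If you launch experiments from a Python script instead of a shell, the same step could be sketched as below. This assumes the variable only needs to be present in the environment before disent's experiment runner starts; the helper `set_disent_config_dir` is hypothetical and not part of disent. It resolves a relative path to an absolute one and fails early if the directory is missing.

```python
import os
from pathlib import Path


def set_disent_config_dir(config_dir: str) -> str:
    """Hypothetical helper: point DISENT_CONFIGS_PREPEND at an absolute config directory."""
    # resolve relative paths (e.g. "research/config") against the current working directory
    path = Path(config_dir).expanduser().resolve()
    if not path.is_dir():
        raise NotADirectoryError(f"config directory does not exist: {path}")
    # only affects this process and any child processes started after this call
    os.environ["DISENT_CONFIGS_PREPEND"] = str(path)
    return str(path)
```

For the msc-research layout, calling `set_disent_config_dir("research/config")` from the repository root would yield the `<root>/research/config` value mentioned above.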

To register your code with disent from a config in the directory referenced above, add the following keys to the root of that config:

```yaml
experiment:
  plugins:
    - your_plugin.submodule.register_to_disent
    # for the example above, this is `research.code.register_to_disent`, which leads to `research/code/__init__.py`
```

Each plugin entry is a dotted path: `your_plugin.submodule` is a python module, and `register_to_disent` is a function in that module which, when called, registers all of your additional classes with the disent registry. See the full example at https://github.com/nmichlo/msc-research/blob/main/research/code/__init__.py
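To make the dotted-path convention concrete, here is one way such an entry could be split into a module and a function name, imported and called with only the standard library. This is an illustrative sketch under that assumption, not disent's actual plugin loader; `resolve_function` and `load_plugins` are hypothetical names.

```python
import importlib
from typing import Any, Callable, Dict, List


def resolve_function(path: str) -> Callable[..., Any]:
    """Split 'pkg.module.fn' into module 'pkg.module' and attribute 'fn', then import it."""
    module_name, _, fn_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path like 'pkg.module.fn', got: {path!r}")
    module = importlib.import_module(module_name)
    fn = getattr(module, fn_name)
    if not callable(fn):
        raise TypeError(f"{path!r} does not refer to a callable")
    return fn


def load_plugins(config: Dict[str, Any]) -> List[str]:
    """Call every registration function listed under experiment.plugins (hypothetical helper)."""
    # a missing `experiment` or `plugins` key means there is nothing to register
    paths = config.get("experiment", {}).get("plugins", [])
    for path in paths:
        resolve_function(path)()  # e.g. research.code.register_to_disent()
    return list(paths)
```

With the msc-research repository on `sys.path`, `load_plugins({"experiment": {"plugins": ["research.code.register_to_disent"]}})` would import `research/code/__init__.py` and call its `register_to_disent()`.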

For example:

```python
# e.g. research/code/__init__.py
import disent.registry as R


def register_to_disent():
    # register metrics
    R.METRICS.setm['flatness'] = R.LazyImport('research.code.metrics._flatness.metric_flatness')

    # groundtruth -- impl synthetic
    R.DATASETS.setm['xyblocks'] = R.LazyImport('research.code.dataset.data._groundtruth__xyblocks')

    # [AE - EXPERIMENTAL]
    R.FRAMEWORKS.setm['x__dot_ae'] = R.LazyImport('research.code.frameworks.ae._unsupervised__dotae.DataOverlapTripletAe')

    # [VAE - EXPERIMENTAL]
    R.FRAMEWORKS.setm['x__dot_vae'] = R.LazyImport('research.code.frameworks.vae._unsupervised__dotvae.DataOverlapTripletVae')

    # register kernels for loss functions, matched by name via a regex (e.g. 'xy1_abs63')
    R.KERNELS.setm.register_regex(pattern=r'^(xy1)_abs(63)$', example='xy1_abs63', factory_fn='research.code.dataset.transform._augment._make_xy1_abs63')
```
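The registrations above map names to `R.LazyImport(...)` objects, so the referenced module is presumably only imported when an entry is actually used, while `register_regex` lets a whole family of names (such as `xy1_abs63`) be built on demand from a regex match. The toy stand-in below sketches that idea with the standard library only. `ToyLazyImport` and `ToyRegistry` are hypothetical and do not reproduce disent's real classes: the attribute is called `setm` only to mirror the snippet, the real `factory_fn` is a dotted string rather than a callable, and passing the regex groups as positional arguments is an assumption.

```python
import importlib
import re
from typing import Any, Callable, Dict, List, Tuple


class ToyLazyImport:
    """Defers importing 'pkg.module.attr' until the first time it is resolved."""

    def __init__(self, path: str):
        self.path = path
        self._obj: Any = None
        self._loaded = False

    def resolve(self) -> Any:
        if not self._loaded:
            module_name, _, attr = self.path.rpartition(".")
            self._obj = getattr(importlib.import_module(module_name), attr)
            self._loaded = True
        return self._obj


class ToyRegistry:
    """Looks up exact names first, then falls back to regex patterns with factories."""

    def __init__(self):
        self.setm: Dict[str, ToyLazyImport] = {}
        self._patterns: List[Tuple[re.Pattern, Callable[..., Any]]] = []

    def register_regex(self, pattern: str, example: str, factory_fn: Callable[..., Any]) -> None:
        compiled = re.compile(pattern)
        # sanity check: the documented example name must match its own pattern
        if not compiled.search(example):
            raise ValueError(f"example {example!r} does not match pattern {pattern!r}")
        self._patterns.append((compiled, factory_fn))

    def __getitem__(self, name: str) -> Any:
        if name in self.setm:
            return self.setm[name].resolve()
        for compiled, factory_fn in self._patterns:
            match = compiled.search(name)
            if match:
                # assumption: captured groups are handed to the factory as strings
                return factory_fn(*match.groups())
        raise KeyError(name)


if __name__ == "__main__":
    registry = ToyRegistry()
    registry.setm["sqrt"] = ToyLazyImport("math.sqrt")  # nothing is imported yet
    registry.register_regex(r"^(xy1)_abs(63)$", example="xy1_abs63",
                            factory_fn=lambda name, size: (name, int(size)))
    print(registry["sqrt"](9.0))   # import happens here, on first lookup
    print(registry["xy1_abs63"])   # built from the regex groups ('xy1', '63')
```

Deferring imports this way keeps registration cheap even when a plugin registers many experimental frameworks that a given run never uses.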