🧶 Modular VAE disentanglement framework for Python, built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
Losses currently have to be defined at the framework level through the use of overrides.
It would be nice if losses could be built directly, for example:

```python
beta_vae_loss = 4 * KlRegLoss() + MseRecLoss()
# or even
beta_vae_loss = Param(cfg, 'beta') * KlRegLoss() + MseRecLoss()

# computing the loss for a training step as
loss, logs = beta_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)
```

```python
dfc_vae_loss = Param(cfg, 'beta') * KlRegLoss() + 0.5 * MseRecLoss() + 0.5 * DfcRecLoss()

# computing the loss for a training step as
loss, logs = dfc_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)
```
Alternative data structures should be discussed.
This would require a substantial rewrite; however, existing frameworks could define these loss components in their constructors.
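The expression syntax above could be sketched with plain operator overloading: `__rmul__` reweights a component and `__add__` collects components into a sum that is itself a component. The sketch below is a minimal illustration, not the library's API: the class names mirror the proposal, but the simplified `compute_loss(x_recon, x_targ)` signature (instead of the full `ds_posterior, ds_prior, ...` signature), the flat-list MSE, and the log-key scheme are all assumptions for demonstration purposes.

```python
import copy


class LossComponent:
    """Base class for composable loss terms (illustrative sketch).

    Overloads `*` and `+` so weighted sums of components can be written
    as expressions, e.g. `4 * KlRegLoss() + MseRecLoss()`.
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __rmul__(self, scalar):
        # `scalar * component` returns a reweighted copy of the component
        reweighted = copy.copy(self)
        reweighted.weight = self.weight * scalar
        return reweighted

    def __add__(self, other):
        return SummedLoss([self, other])

    def _compute(self, x_recon, x_targ):
        raise NotImplementedError

    def compute_loss(self, x_recon, x_targ):
        # NOTE: simplified signature; the proposal passes distributions and
        # latent samples as well (ds_posterior, ds_prior, zs_sampled, ...)
        value = self.weight * self._compute(x_recon, x_targ)
        return value, {type(self).__name__: value}


class SummedLoss(LossComponent):
    """A sum of loss components; itself a component, so sums can be chained."""

    def __init__(self, components):
        super().__init__()
        self.components = list(components)

    def __add__(self, other):
        return SummedLoss([*self.components, other])

    def compute_loss(self, x_recon, x_targ):
        total, logs = 0.0, {}
        for component in self.components:
            value, sub_logs = component.compute_loss(x_recon, x_targ)
            total += value
            # NOTE: a real implementation would need unique log keys when
            # the same component class appears more than once in a sum
            logs.update(sub_logs)
        logs['loss'] = total
        return total, logs


class MseRecLoss(LossComponent):
    """Mean squared error reconstruction term, over flat lists of floats."""

    def _compute(self, x_recon, x_targ):
        return sum((a - b) ** 2 for a, b in zip(x_recon, x_targ)) / len(x_recon)
```

A `Param(cfg, 'beta')` weight would fit the same pattern if the weight were evaluated lazily at `compute_loss` time rather than stored as a float. Usage matches the proposed expression syntax:

```python
loss, logs = (2 * MseRecLoss() + MseRecLoss()).compute_loss([0.0, 0.0], [1.0, 1.0])
# loss is 2 * 1.0 + 1 * 1.0 == 3.0
```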