nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

Modular Losses #11

Closed nmichlo closed 3 years ago

nmichlo commented 3 years ago

Losses currently have to be defined at the framework level by overriding methods.

It would be nicer if losses could be composed directly, for example:

# losses composed from weighted terms:
beta_vae_loss = 4 * KlRegLoss() + MseRecLoss()
# or with the weight looked up from the config:
beta_vae_loss = Param(cfg, 'beta') * KlRegLoss() + MseRecLoss()

# the loss for a training step would then be computed as:
loss, logs = beta_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)

# more complex frameworks would compose additional terms, e.g. DFC-VAE:
dfc_vae_loss = Param(cfg, 'beta') * KlRegLoss() + 0.5 * MseRecLoss() + 0.5 * DfcRecLoss()

# computed for a training step in the same way:
loss, logs = dfc_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)

Alternative data structures for representing these composite losses should also be discussed.

This would require a substantial rewrite; however, existing frameworks could define these loss components in their constructors.
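The composition above could be implemented with operator overloading on a common loss-term base class. The following is a minimal sketch of that mechanism only; the class names (`LossTerm`, `WeightedLoss`, `SumLoss`) and the toy leaf terms returning plain floats are assumptions for illustration, not disent's actual API — real terms would operate on the tensors passed to `compute_loss`.

```python
# Hypothetical sketch: composable loss terms via __add__ / __rmul__.
# All names here are illustrative assumptions, not part of disent.

class LossTerm:
    def __add__(self, other):
        return SumLoss(self, other)
    def __rmul__(self, weight):
        # enables `4 * KlRegLoss()` and `Param(cfg, 'beta') * KlRegLoss()`
        return WeightedLoss(weight, self)
    def compute_loss(self, **inputs):
        raise NotImplementedError

class WeightedLoss(LossTerm):
    def __init__(self, weight, term):
        self.weight, self.term = weight, term
    def compute_loss(self, **inputs):
        loss, logs = self.term.compute_loss(**inputs)
        return self.weight * loss, logs

class SumLoss(LossTerm):
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
    def compute_loss(self, **inputs):
        loss0, logs0 = self.lhs.compute_loss(**inputs)
        loss1, logs1 = self.rhs.compute_loss(**inputs)
        # merge the per-term logs so each component is still tracked
        return loss0 + loss1, {**logs0, **logs1}

# toy leaf terms that return scalars instead of tensors, purely to show the wiring:
class KlRegLoss(LossTerm):
    def compute_loss(self, **inputs):
        return inputs['kl'], {'kl_loss': inputs['kl']}

class MseRecLoss(LossTerm):
    def compute_loss(self, **inputs):
        return inputs['rec'], {'rec_loss': inputs['rec']}

beta_vae_loss = 4 * KlRegLoss() + MseRecLoss()
loss, logs = beta_vae_loss.compute_loss(kl=0.25, rec=0.5)
# loss == 4 * 0.25 + 0.5 == 1.5
```

A `Param(cfg, 'beta')` weight would just need to support `weight * loss` lazily (e.g. by resolving the config value inside `WeightedLoss.compute_loss`), which keeps the composed loss re-configurable without rebuilding it.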

nmichlo commented 3 years ago

Too much work to implement at the moment.