nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

About metrics #48

Closed: Rainbow-Six66 closed this issue 5 months ago

Rainbow-Six66 commented 6 months ago

Hi, great package!

How should I interpret the num_train and batch_size arguments of metric_dci and metric_mig? Also, are there any examples of using the FactorVAE metric? Thank you!

Rainbow-Six66 commented 6 months ago

This is my code:

```python
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import DSpritesData
from disent.dataset.transform import ToImgTensorF32
from disent.metrics import metric_dci, metric_mig, metric_factor_vae

from model.β_vae import BetaVAE_H


def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    data = DSpritesData()
    dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None)
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

    checkpoint = torch.load('./checkpoints/beta_vae50.pt', map_location=device)
    model = BetaVAE_H().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # we cannot guarantee which device the representation is on
    get_repr = lambda x: model.mlp_encoder(x.to(device))

    # evaluate
    return {
        **metric_dci(dataset, get_repr, num_train=20, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20),
        **metric_factor_vae(dataset, get_repr, num_train=20),
    }


a_results = train()
print('beta=4: ', a_results)
```

nmichlo commented 6 months ago

Hi there, and thank you!

Unfortunately, the docs for this are sparse. I understand this is not ideal and would gladly accept PRs to improve them.

However, for context, the MIG, DCI and FactorVAE scores are largely based on those from https://github.com/google-research/disentanglement_lib (the default values should be similar). From what I remember, without looking at the code, num_train and batch_size affect the sample size of the underlying data used to compute the metrics. With too little data the metrics will be inaccurate; with too much, processing takes too long. For metrics computed during training, I often lower these values, and then do a final, larger computation at the end with the default values.
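To make the "sample size" point concrete, here is a minimal numpy sketch of the FactorVAE score as described in disentanglement_lib, which is a majority-vote classifier. It is not disent's implementation. The names `sample_factors` and `factor_vae_score` are hypothetical. The toy representation maps ground-truth factors directly, with no images or encoder. Dimensions are pruned only when their spread is exactly zero, and only training accuracy is returned. In this procedure, `num_train` is the number of votes and `batch_size` is the number of observations behind each vote, so both control how much data the estimate sees. Whether disent's `metric_factor_vae` uses `batch_size` the same way is an assumption worth checking against its source.

```python
import numpy as np


def sample_factors(factor_sizes, n, rng):
    """Draw n random ground-truth factor vectors (one integer per factor)."""
    return np.stack([rng.integers(0, size, n) for size in factor_sizes], axis=1)


def factor_vae_score(get_repr, factor_sizes, num_train=1000, batch_size=64,
                     num_variance_estimate=1000, seed=0):
    """Simplified FactorVAE score: training accuracy of a majority-vote classifier.

    Toy setup (hypothetical): get_repr maps factor arrays of shape (n, K)
    straight to representations of shape (n, L), skipping the image/encoder step.
    """
    rng = np.random.default_rng(seed)
    # 1. Global scale of each latent dimension, estimated from random samples.
    global_std = np.std(get_repr(sample_factors(factor_sizes, num_variance_estimate, rng)), axis=0)
    active = global_std > 0  # ignore collapsed dimensions (simplified pruning rule)
    votes = np.zeros((global_std.shape[0], len(factor_sizes)), dtype=np.int64)
    # 2. num_train votes, each built from batch_size observations.
    for _ in range(num_train):
        k = rng.integers(len(factor_sizes))
        factors = sample_factors(factor_sizes, batch_size, rng)
        factors[:, k] = factors[0, k]  # hold factor k fixed across the batch
        z = get_repr(factors)
        variances = np.var(z[:, active] / global_std[active], axis=0)
        d = np.flatnonzero(active)[np.argmin(variances)]  # least-varying active dim
        votes[d, k] += 1
    # 3. Each latent dim predicts its most-voted factor; score = fraction classified correctly.
    return votes.max(axis=1).sum() / num_train


print(factor_vae_score(lambda f: f.astype(float), [3, 4, 5], num_train=200))
```

With a perfectly axis-aligned representation the score is 1.0, and the same holds under any permutation of the dimensions. A representation that ignores the factors scores far lower. Lowering `num_train` makes the vote counts, and therefore the score, noisier.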

nmichlo commented 6 months ago

Hydra config for the experiment metrics: https://github.com/nmichlo/disent/tree/8f061a87076adeae8d6e5b0fa984b660cd40e026/experiment/config/metrics

Actual code that selects these: https://github.com/nmichlo/disent/blob/8f061a87076adeae8d6e5b0fa984b660cd40e026/experiment/run.py#L208-L209

metric wrapper:

fast version kwargs:

NOTE: the kwargs for the fast versions were chosen arbitrarily. The standard versions should follow the kwargs from disentanglement_lib.
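The "fast during training, standard at the end" pattern can be sketched as a small lookup. `FAST_METRIC_KWARGS` and `metric_kwargs` are hypothetical names, and the numbers are placeholders, not the values in disent's Hydra configs. The only assumption about disent is that the metric functions accept the `num_train` / `batch_size` keyword arguments mentioned in the question.

```python
from typing import Any, Dict

# Hypothetical reduced settings for quick checks during training.
# The numbers are illustrative only, NOT disent's actual "fast" config values.
FAST_METRIC_KWARGS: Dict[str, Dict[str, Any]] = {
    "dci": {"num_train": 1000, "batch_size": 64},
    "mig": {"num_train": 2000, "batch_size": 64},
    "factor_vae": {"num_train": 500},
}


def metric_kwargs(metric_name: str, fast: bool) -> Dict[str, Any]:
    """Pick kwargs for one metric (hypothetical helper).

    fast=True  -> reduced sample sizes: cheaper but noisier estimates.
    fast=False -> {} so the metric function falls back to its own defaults,
                  which are meant to follow disentanglement_lib.
    """
    if not fast:
        return {}
    # return a copy so callers can tweak it without mutating the shared table
    return dict(FAST_METRIC_KWARGS[metric_name])


# Example usage with the metric functions from the question (not executed here):
#   metric_mig(dataset, get_repr, **metric_kwargs("mig", fast=True))   # during training
#   metric_mig(dataset, get_repr, **metric_kwargs("mig", fast=False))  # final evaluation
```

Returning `{}` for the standard case keeps the defaults in one place, the metric functions themselves, rather than duplicating them in the calling code.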

NOTE: batch_size is like the batch size of a dataset loader. The model is often run inside these metrics (on the GPU if possible), so batch_size sets how many observations are passed through it at a time.
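Here is a minimal sketch of how a metric might push its sampled observations through the representation function in chunks of `batch_size`. `encode_in_batches` is a hypothetical helper, and numpy stands in for torch so the sketch runs anywhere. In practice `get_repr` would be the model-calling lambda from the question, ideally run under `torch.no_grad()`. In this sketch, as long as `get_repr` is deterministic (e.g. it returns the encoder mean rather than a sample), `batch_size` only trades memory for speed and does not change the result.

```python
from typing import Callable

import numpy as np


def encode_in_batches(get_repr: Callable[[np.ndarray], np.ndarray],
                      observations: np.ndarray,
                      batch_size: int = 16) -> np.ndarray:
    """Apply get_repr to `observations`, batch_size rows at a time, and stack the outputs."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    chunks = [
        get_repr(observations[start:start + batch_size])  # one forward pass per chunk
        for start in range(0, len(observations), batch_size)
    ]
    return np.concatenate(chunks, axis=0)


# Toy stand-in for an encoder: keep the first two features, scaled.
toy_repr = lambda x: 2.0 * x[:, :2]
obs = np.arange(50 * 4, dtype=float).reshape(50, 4)
small = encode_in_batches(toy_repr, obs, batch_size=7)
large = encode_in_batches(toy_repr, obs, batch_size=64)
print(small.shape, np.array_equal(small, large))  # (50, 2) True
```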