nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev

[Q]: What's the native way to split datasets into train, validation and test? #22

Closed akekic closed 2 years ago

akekic commented 2 years ago

I'm trying to train a VAE on Cars3dData and I was wondering how to split an instance of DisentDataset into train, validation, and test sets. Is there a dedicated sampler that does this?

Here is the backbone of what I am trying to run:

import pytorch_lightning as pl
import torch

from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data
data = Cars3dData()
size = 64
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
dataset_train = DisentDataset(data, transform=ToImgTensorF32(size=size, mean=vis_mean, std=vis_std))
# dataset_val = ?
# dataset_test = ?

dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum"
    ),
)

# train the model
trainer = pl.Trainer(
    max_steps=10,
    checkpoint_callback=False,
    fast_dev_run=is_test_run(),
    gpus=1 if torch.cuda.is_available() else None,
)
trainer.fit(module, dataloader_train)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
    **metric_dci(
        dataset_train, get_repr, num_train=1000, num_test=500, show_progress=True
    ),
    **metric_mig(dataset_train, get_repr, num_train=2000),
}

# evaluate
print("metrics:", metrics)

Any hints are highly appreciated. Thank you for providing this package!

Best regards, Armin

nmichlo commented 2 years ago

Hi there, sorry for the delayed response!

Unfortunately this is something that I will need to add to the roadmap. I am just not entirely sure myself how to approach this problem when it comes to samplers/metrics that require ground-truth datasets.

A workaround for your current code may be:

import math

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64

# normalise the data
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
data = Cars3dData(transform=ToImgTensorF32(size=64, mean=vis_mean, std=vis_std))

# SOLUTION:
# -- split the data using built-in functions (no longer ground-truth datasets, but subsets)
n_train = int(math.floor(len(data) * 0.6))
n_val   = int(math.floor(len(data) * 0.2))
n_test  = len(data) - n_train - n_val  # remainder, so the split sizes always sum to len(data)
data_train, data_val, data_test = torch.utils.data.random_split(data, [n_train, n_val, n_test])
# -- create multiple disent datasets
dataset_train = DisentDataset(data_train)
dataset_val   = DisentDataset(data_val)
dataset_test  = DisentDataset(data_test)
# -- create dataloaders
dataloader_train = DataLoader(dataset=dataset_train, batch_size=4, shuffle=True, num_workers=0)
dataloader_val   = DataLoader(dataset=dataset_val, batch_size=4, shuffle=True, num_workers=0)
dataloader_test  = DataLoader(dataset=dataset_test, batch_size=4, shuffle=True, num_workers=0)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3)
    ),
)

# PROBLEM: unfortunately the framework does not yet implement the pytorch-lightning validation step
#          I'll add this to the roadmap and this should work in future.
trainer = pl.Trainer(max_steps=10000, checkpoint_callback=False, gpus=1 if torch.cuda.is_available() else None)
trainer.fit(module, dataloader_train, dataloader_val)

# PROBLEM: unfortunately the metrics will no longer work with the subsets
#          of data. You could instead pass the original full dataset to the
#          metrics, but this may be considered an information leak?
#          -- This will crash!
get_repr = lambda x: module.encode(x.to(module.device))
metrics = {
    **metric_dci(dataset_test, get_repr, num_train=1000, num_test=500, show_progress=True),
    **metric_mig(dataset_test, get_repr, num_train=2000),
}
print("metrics:", metrics)
nmichlo commented 2 years ago

This has been fixed in 5695747c1e94420c024f1505d9b8a4b3c81ad610 and released in v0.3.4.

Frameworks now support basic validation and testing by reusing the code from the training step; however, schedules might be broken if these are used.

A new example has been added to the docs: https://github.com/nmichlo/disent/blob/5695747c1e94420c024f1505d9b8a4b3c81ad610/docs/examples/overview_framework_train_val.py
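
In rough terms, the earlier workaround can now be extended along these lines (a sketch reusing the subsets and module defined above, not the exact contents of the linked example; batch sizes and step counts are placeholders):

# -- validation and test dataloaders for the subsets created earlier
dataset_val  = DisentDataset(data_val)
dataset_test = DisentDataset(data_test)
dataloader_val  = DataLoader(dataset_val,  batch_size=4, shuffle=False, num_workers=0)
dataloader_test = DataLoader(dataset_test, batch_size=4, shuffle=False, num_workers=0)

# -- the validation step now runs during fit, and testing can be run afterwards
trainer = pl.Trainer(max_steps=10000, checkpoint_callback=False, gpus=1 if torch.cuda.is_available() else None)
trainer.fit(module, dataloader_train, dataloader_val)
trainer.test(module, dataloader_test)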

ema-marconato commented 1 year ago

It doesn't seem possible to use the train/val/test partition for AdaVAE training. Is there any way around this?

nmichlo commented 1 year ago

@ema-marconato So there are different sampling strategies in the original paper that can be used in different cases.

Unfortunately only the fully random sampling strategies work with the training and validation splits.

from disent.dataset.sampling import RandomSampler

The other strategies use the ground-truth factor information to enforce certain characteristics:

from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.dataset.sampling import GroundTruthPairSampler

It is possible that a random sampler could be written that tries to enforce the constraints imposed by these ground-truth samplers, but such a sampler is not currently implemented.
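
As a rough sketch of what that means in practice for AdaVAE on split data, reusing data_train / data_val from the workaround earlier in this thread (assumption: RandomSampler takes the number of observations to draw per item, 2 here since AdaVAE operates on pairs):

from disent.dataset import DisentDataset
from disent.dataset.sampling import RandomSampler

# -- AdaVAE trains on pairs of observations, so randomly draw 2 items at a time
# -- (assumption: num_samples is the number of observations drawn per sample)
dataset_train = DisentDataset(data_train, sampler=RandomSampler(num_samples=2))
dataset_val   = DisentDataset(data_val,   sampler=RandomSampler(num_samples=2))

# -- the ground-truth pair samplers above cannot be used here, since the random
# -- subsets no longer expose the ground-truth factors needed to choose pairs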