Closed RylanSchaeffer closed 8 months ago
KNN evaluation during training is a bit tricky; the simplest or best approach depends on your setup.
I use PyTorch Lightning. I was hoping to be able to use both small (e.g., CIFAR10) and large (e.g., ImageNet 1k) datasets on single and multiple-gpu runs, respectively.
I was also hoping to be able to switch between KNN and linear evaluation.
This is my current attempt (only linear evaluation for now), but I need to debug it. For some reason, gradients can't backpropagate.
class MultiViewSSLEvalCallback(lightning.Callback):
    """Train an affine classification probe on backbone embeddings after validation.

    At the end of every validation epoch of the pretraining run, the backbone
    is used to embed the fine-tuning dataset once, and a fresh
    ``MultiViewSSLAffineClassificationEvalSystem`` is fitted on the cached
    embeddings with its own, nested ``pl.Trainer``.
    """

    def on_validation_epoch_end(self, trainer, pl_module):
        """Embed the fine-tuning data and fit/validate the affine probe.

        Args:
            trainer: The outer (pretraining) trainer. Not used directly; the
                probe is trained by a separate, locally created trainer.
            pl_module: The pretraining module. Must expose ``wandb_config``,
                ``wandb_logger`` and ``ssl_system.backbone``.
        """
        # Lightning runs validation hooks with autograd switched off
        # (``torch.no_grad`` or ``torch.inference_mode``, depending on the
        # version/configuration). Any parameter or tensor created in that
        # state cannot take part in backpropagation, which is why the probe
        # could not be trained. Re-enable autograd for everything below.
        with torch.inference_mode(False), torch.enable_grad():
            embedded_data_by_split = self.embed_data_using_backbone(
                wandb_config=pl_module.wandb_config,
                backbone=pl_module.ssl_system.backbone,
            )

            # Data loaders
            train_loader = torch.utils.data.DataLoader(
                embedded_data_by_split["train"],
                batch_size=pl_module.wandb_config["finetune_batch_size"],
                shuffle=True,
                drop_last=True,
                pin_memory=True,
                num_workers=0,  # Without these: RuntimeError: DataLoader worker exited unexpectedly
            )
            # TODO: Should this be val or test?
            val_loader = torch.utils.data.DataLoader(
                embedded_data_by_split["val"],
                batch_size=pl_module.wandb_config["finetune_batch_size"],
                shuffle=False,
                # Keep the final partial batch: dropping it would silently
                # exclude validation samples from the reported accuracy.
                drop_last=False,
                pin_memory=True,
                num_workers=0,  # Without these: RuntimeError: DataLoader worker exited unexpectedly
            )

            train_embeddings, train_labels = embedded_data_by_split["train"].tensors
            finetune_system = src.systems.MultiViewSSLAffineClassificationEvalSystem(
                feature_dim=train_embeddings.shape[1],
                # Labels are assumed to be integers in [0, num_classes).
                num_classes=train_labels.max().item() + 1,
                max_finetune_epochs=pl_module.wandb_config["finetune_n_epochs"],
                finetune_learning_rate=pl_module.wandb_config["finetune_learning_rate"],
                finetune_learning_rate_scheduler=pl_module.wandb_config[
                    "finetune_learning_rate_scheduler"
                ],
                finetune_weight_decay=pl_module.wandb_config["finetune_weight_decay"],
            )
            # For some reason, we need to place the finetune system in train mode.
            finetune_system.train()

            # Use a distinct name so the outer ``trainer`` argument is not shadowed.
            finetune_trainer = pl.Trainer(
                default_root_dir=os.path.join(
                    pl_module.wandb_config["run_checkpoint_dir"],
                    "affine_classification_eval",
                ),
                # ``None`` is not a valid accelerator in current Lightning
                # releases; fall back to the CPU explicitly.
                accelerator="gpu" if torch.cuda.is_available() else "cpu",
                # devices=1,
                max_epochs=pl_module.wandb_config["finetune_n_epochs"],
                callbacks=[
                    ModelCheckpoint(
                        save_weights_only=True, mode="max", monitor="finetune/val_acc"
                    ),
                    LearningRateMonitor("epoch"),
                ],
                logger=pl_module.wandb_logger,
                # enable_progress_bar=False,
                check_val_every_n_epoch=10,
                # log_every_n_steps=1,
                profiler="simple",
            )
            # Record the probe's accuracy before any fine-tuning, then train it.
            finetune_trainer.validate(model=finetune_system, dataloaders=val_loader)
            finetune_trainer.fit(
                model=finetune_system,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader,
            )

    @staticmethod
    def embed_data_using_backbone(
        wandb_config: Dict[str, Any], backbone: pl.LightningModule
    ) -> Dict[str, torch.utils.data.TensorDataset]:
        """Embed the ``train`` and ``val`` fine-tuning splits with ``backbone``.

        Args:
            wandb_config: Run configuration. Reads ``finetune_dataset``,
                ``dataset_dir``, ``finetune_dataset_kwargs``,
                ``finetune_dataset_sample_percent`` and ``seed``.
            backbone: Module whose call returns a mapping with an ``"outputs"``
                entry; the first view along dimension 1 is kept.

        Returns:
            A dict with keys ``"train"`` and ``"val"``, each a ``TensorDataset``
            of ``(embeddings, labels)`` stored on the CPU.
        """
        print("Embedding data using backbone...")

        # Build both datasets up front (same arguments, only the split differs).
        datasets_by_split = {}
        for split in ("train", "val"):
            dataset, _ = src.data.create_datasets(
                dataset_str=wandb_config["finetune_dataset"],
                split=split,
                dataset_dir=wandb_config["dataset_dir"],
                dataset_kwargs=wandb_config["finetune_dataset_kwargs"],
                n_views=1,
                sample_percent=wandb_config["finetune_dataset_sample_percent"],
                seed=wandb_config["seed"],
                **wandb_config["finetune_dataset_kwargs"],
            )
            datasets_by_split[split] = dataset

        # Prepare model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        backbone.to(device)

        # Embed in eval mode so layers such as batch norm / dropout behave
        # deterministically and do not update running statistics; the previous
        # mode is restored afterwards.
        was_training = backbone.training
        backbone.eval()

        embedded_datasets_by_split = {}
        try:
            # ``inference_mode(False)`` guarantees the cached embeddings are
            # ordinary tensors (usable as inputs to a trainable probe) even
            # when this is called from inside an inference-mode region;
            # ``no_grad`` still avoids building a graph through the backbone.
            with torch.inference_mode(False), torch.no_grad():
                for split, dataset in datasets_by_split.items():
                    # Encode all images
                    data_loader = torch.utils.data.DataLoader(
                        dataset,
                        batch_size=64,
                        num_workers=2,
                        shuffle=False,
                        drop_last=False,
                    )
                    # TODO: how should augmentations be handled here?
                    embeddings, labels = [], []
                    for batch_imgs, batch_labels in tqdm.tqdm(data_loader):
                        batch_imgs = batch_imgs.to(device)
                        batch_embeddings = backbone(batch_imgs)["outputs"]
                        # The second dimension is the number of views, which is set to 1. Remove it.
                        embeddings.append(batch_embeddings[:, 0].detach().cpu())
                        labels.append(batch_labels)
                        # break  # useful for fast debugging.
                    embedded_datasets_by_split[split] = torch.utils.data.TensorDataset(
                        torch.cat(embeddings, dim=0),
                        torch.cat(labels, dim=0),
                    )
        finally:
            backbone.train(was_training)

        print("Finished embedding data using backbone.")
        return embedded_datasets_by_split
I then use it as:
# Register the downstream-evaluation callback next to the usual
# learning-rate monitor and checkpointing callbacks.
finetune_eval_callback = MultiViewSSLEvalCallback()
callbacks = [lr_monitor_callback, checkpoint_callback, finetune_eval_callback]
...
if __name__ == "__main__":
    # Echo the resolved W&B configuration so it ends up in the run's stdout.
    pp = pprint.PrettyPrinter(indent=4)
    print("W&B Config:")
    pp.pprint(wandb_config)

    trainer = pl.Trainer(
        # --- Hardware ---
        accelerator="gpu",
        # devices="4",
        # strategy='ddp',
        precision=wandb_config["precision"],
        # --- Optimisation schedule ---
        max_epochs=wandb_config["pretrain_n_epochs"],
        accumulate_grad_batches=wandb_config["accumulate_grad_batches"],
        gradient_clip_val=wandb_config["gradient_clip_val"],
        deterministic=True,
        # --- Validation ---
        check_val_every_n_epoch=wandb_config["check_val_every_n_epoch"],
        num_sanity_val_steps=0,  # -1 means runs all of validation before starting to train.
        # --- Logging / checkpointing ---
        callbacks=callbacks,
        default_root_dir=run_checkpoint_dir,
        logger=wandb_logger,
        log_every_n_steps=1,
        # --- Debugging aids ---
        fast_dev_run=True,
        # fast_dev_run=False,
        # overfit_batches=1,  # useful for debugging
        # limit_train_batches=0.01,
        profiler="simple",  # Simplest profiler
        # profiler="advanced",  # More advanced profiler
        # profiler=PyTorchProfiler(filename=),  # PyTorch specific profiler
    )

    # Explicitly validate before beginning training.
    trainer.validate(model=pretrain_system, datamodule=datamodule)
    trainer.fit(model=pretrain_system, datamodule=datamodule)
If this isn't a good approach, could you please tell me what you'd recommend?
I haven't found a nice solution for PyTorch Lightning so far; in the end, we decided to run linear evaluation only at the end of training. As you noticed, the issue is that you have to create a new trainer instance inside the model code, and I am not sure whether this works well with PyTorch Lightning.
One thing you can do is use online linear evaluation by adding a classification layer to the SSL module and training it during pretraining. We have a module for this here: https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking/online_linear_classifier.py#L10 and here is an example on how to use it: https://github.com/lightly-ai/lightly/blob/a5ef7d07a8233466c407307f010e7531c85d99b0/benchmarks/imagenet/resnet50/simclr.py#L31
> we decided to run linear evaluation only at the end of training.
This seems potentially risky to me, no? I can easily imagine spending compute on pretraining only to find out at the end that the network is useless.
Thank you though for the tip of adding a classification layer and doing online training :)
I'll close this for now, let me know if you have further questions :)
Hi! I'd like to know how to pretrain a model (e.g. SimCLR) with occasional downstream evaluation (e.g. linear classification, KNN), say every 10 pretraining epochs. But I can't find documentation about how to do this. Can you please tell me how to do this, or point me towards documentation that explains how to do this?
Thank you!