lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Intermittent downstream evaluation during pretraining #1420

RylanSchaeffer closed this issue 8 months ago.

RylanSchaeffer commented 10 months ago

Hi! I'd like to pretrain a model (e.g., SimCLR) with periodic downstream evaluation (e.g., linear classification or KNN), say every 10 pretraining epochs, but I can't find any documentation on this. Could you explain how to do it, or point me to documentation that does?

Thank you!

guarin commented 10 months ago

KNN evaluation during training is a bit tricky, and the simplest or best approach depends on your setup.
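A minimal sketch of KNN evaluation on precomputed embeddings, assuming the current backbone's train-split embeddings serve as the labeled "bank" and the val-split embeddings are the queries. It uses cosine similarity and temperature-weighted votes over the k nearest neighbors. The function name `weighted_knn_predict` and the defaults `k=20`, `temperature=0.1` are hypothetical choices for illustration, not lightly's API. Part of what makes this tricky during training is producing the bank at all: it has to be re-embedded with the current backbone each time, and in multi-GPU runs each process typically only sees a shard of the data.

```python
import torch
import torch.nn.functional as F


def weighted_knn_predict(
    query: torch.Tensor,        # (N_query, D) embeddings to classify (e.g. val split)
    bank: torch.Tensor,         # (N_bank, D) embeddings of the labeled train split
    bank_labels: torch.Tensor,  # (N_bank,) integer class labels (int64)
    num_classes: int,
    k: int = 20,                # hypothetical default
    temperature: float = 0.1,   # hypothetical default
) -> torch.Tensor:
    """Weighted KNN on cosine similarity; returns predicted labels of shape (N_query,)."""
    query = F.normalize(query, dim=1)
    bank = F.normalize(bank, dim=1)
    sim = query @ bank.T                          # cosine similarities, (N_query, N_bank)
    k = min(k, bank.shape[0])
    topk_sim, topk_idx = sim.topk(k, dim=1)       # k nearest neighbors per query
    topk_labels = bank_labels[topk_idx]           # (N_query, k)
    weights = (topk_sim / temperature).exp()      # closer neighbors get larger votes
    scores = torch.zeros(query.shape[0], num_classes, device=query.device)
    scores.scatter_add_(1, topk_labels, weights)  # sum vote weights per class
    return scores.argmax(dim=1)
```

Validation accuracy would then be `(weighted_knn_predict(val_feats, train_feats, train_labels, num_classes) == val_labels).float().mean()`. No training is involved, so unlike linear evaluation this needs no gradients.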

RylanSchaeffer commented 10 months ago

I use PyTorch Lightning. I'm hoping to use both small datasets (e.g., CIFAR-10) on single-GPU runs and large datasets (e.g., ImageNet-1k) on multi-GPU runs.

I'd also like to be able to switch between KNN and linear evaluation.

Here is my current attempt (linear evaluation only, for now). It still needs debugging: for some reason, gradients can't backpropagate when the linear classifier is trained.

import os
from typing import Any, Dict

import lightning.pytorch as pl
import torch
import tqdm
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

import src.data  # the user's own project modules (not part of lightly)
import src.systems


class MultiViewSSLEvalCallback(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # Embed the fine-tuning splits with the current (frozen) SSL backbone.
        embedded_data_by_split = self.embed_data_using_backbone(
            wandb_config=pl_module.wandb_config, backbone=pl_module.ssl_system.backbone
        )

        # Data loaders over the precomputed embeddings
        train_loader = torch.utils.data.DataLoader(
            embedded_data_by_split["train"],
            batch_size=pl_module.wandb_config["finetune_batch_size"],
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            num_workers=0,  # Without this: RuntimeError: DataLoader worker exited unexpectedly
        )

        # TODO: Should this be val or test?
        val_loader = torch.utils.data.DataLoader(
            embedded_data_by_split["val"],
            batch_size=pl_module.wandb_config["finetune_batch_size"],
            shuffle=False,
            drop_last=False,  # Evaluate on every validation sample, not just full batches
            pin_memory=True,
            num_workers=0,  # Without this: RuntimeError: DataLoader worker exited unexpectedly
        )

        finetune_system = src.systems.MultiViewSSLAffineClassificationEvalSystem(
            feature_dim=embedded_data_by_split["train"].tensors[0].shape[1],
            num_classes=embedded_data_by_split["train"].tensors[1].max().item() + 1,
            max_finetune_epochs=pl_module.wandb_config["finetune_n_epochs"],
            finetune_learning_rate=pl_module.wandb_config["finetune_learning_rate"],
            finetune_learning_rate_scheduler=pl_module.wandb_config[
                "finetune_learning_rate_scheduler"
            ],
            finetune_weight_decay=pl_module.wandb_config["finetune_weight_decay"],
        )
        # For some reason, we need to place the finetune system in train mode.
        finetune_system.train()

        # NOTE: this hook runs inside Lightning's validation loop, which by default
        # disables autograd (torch.inference_mode). That is a likely reason why
        # gradients can't backpropagate in the nested fit() below.
        # A separate name avoids shadowing the outer `trainer` argument.
        finetune_trainer = pl.Trainer(
            default_root_dir=os.path.join(
                pl_module.wandb_config["run_checkpoint_dir"],
                "affine_classification_eval",
            ),
            accelerator="auto",  # GPU if available, otherwise CPU
            # devices=1,
            max_epochs=pl_module.wandb_config["finetune_n_epochs"],
            callbacks=[
                ModelCheckpoint(
                    save_weights_only=True, mode="max", monitor="finetune/val_acc"
                ),
                LearningRateMonitor("epoch"),
            ],
            logger=pl_module.wandb_logger,
            # enable_progress_bar=False,
            check_val_every_n_epoch=10,
            # log_every_n_steps=1,
            profiler="simple",
        )

        finetune_trainer.validate(model=finetune_system, dataloaders=val_loader)
        finetune_trainer.fit(
            model=finetune_system,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )

    @staticmethod
    def embed_data_using_backbone(
        wandb_config: Dict[str, Any], backbone: torch.nn.Module
    ) -> Dict[str, torch.utils.data.TensorDataset]:
        print("Embedding data using backbone...")
        # NOTE: finetune_dataset_kwargs is passed both as dataset_kwargs= and unpacked;
        # keep whichever form src.data.create_datasets actually expects.
        train_dataset, _ = src.data.create_datasets(
            dataset_str=wandb_config["finetune_dataset"],
            split="train",
            dataset_dir=wandb_config["dataset_dir"],
            dataset_kwargs=wandb_config["finetune_dataset_kwargs"],
            n_views=1,
            sample_percent=wandb_config["finetune_dataset_sample_percent"],
            seed=wandb_config["seed"],
            **wandb_config["finetune_dataset_kwargs"],
        )

        val_dataset, _ = src.data.create_datasets(
            dataset_str=wandb_config["finetune_dataset"],
            split="val",
            dataset_dir=wandb_config["dataset_dir"],
            dataset_kwargs=wandb_config["finetune_dataset_kwargs"],
            n_views=1,
            sample_percent=wandb_config["finetune_dataset_sample_percent"],
            seed=wandb_config["seed"],
            **wandb_config["finetune_dataset_kwargs"],
        )

        # Prepare model: keep it on the device Lightning already placed it on, and
        # switch to eval mode so BatchNorm statistics are not updated while embedding.
        device = next(backbone.parameters()).device
        was_training = backbone.training
        backbone.eval()

        with torch.no_grad():
            embedded_datasets_by_split = {}
            for split, dataset in [("train", train_dataset), ("val", val_dataset)]:
                # Encode all images
                data_loader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=64,
                    num_workers=2,
                    shuffle=False,
                    drop_last=False,
                )
                # TODO: how should augmentations be handled here?
                embeddings, labels = [], []
                for batch_imgs, batch_labels in tqdm.tqdm(data_loader):
                    batch_imgs = batch_imgs.to(device)
                    batch_embeddings = backbone(batch_imgs)["outputs"]
                    # The second dimension is the number of views, which is set to 1. Remove it.
                    embeddings.append(batch_embeddings[:, 0].detach().cpu())
                    labels.append(batch_labels)
                    # break  # useful for fast debugging.

                embeddings = torch.cat(embeddings, dim=0)
                labels = torch.cat(labels, dim=0)

                embedded_datasets_by_split[split] = torch.utils.data.TensorDataset(
                    embeddings, labels
                )

        # Restore the mode pretraining expects.
        backbone.train(was_training)

        print("Finished embedding data using backbone.")
        return embedded_datasets_by_split

I then use it as:

finetune_eval_callback = MultiViewSSLEvalCallback()
callbacks = [
    lr_monitor_callback,
    checkpoint_callback,
    finetune_eval_callback,
]

...
if __name__ == "__main__":
    pp = pprint.PrettyPrinter(indent=4)
    print("W&B Config:")
    pp.pprint(wandb_config)

    trainer = pl.Trainer(
        accumulate_grad_batches=wandb_config["accumulate_grad_batches"],
        callbacks=callbacks,
        check_val_every_n_epoch=wandb_config["check_val_every_n_epoch"],
        default_root_dir=run_checkpoint_dir,
        deterministic=True,
        accelerator="gpu",
        # devices="4",
        # strategy='ddp',
        fast_dev_run=True,
        # fast_dev_run=False,
        logger=wandb_logger,
        log_every_n_steps=1,
        # overfit_batches=1,  # useful for debugging
        gradient_clip_val=wandb_config["gradient_clip_val"],
        max_epochs=wandb_config["pretrain_n_epochs"],
        num_sanity_val_steps=0,  # -1 would run the full validation set before training starts.
        # limit_train_batches=0.01,
        profiler="simple",  # Simplest profiler
        # profiler="advanced",  # More advanced profiler
        # profiler=PyTorchProfiler(filename=),  # PyTorch specific profiler
        precision=wandb_config["precision"],
    )

    # Explicitly validate before beginning training.
    trainer.validate(model=pretrain_system, datamodule=datamodule)

    trainer.fit(model=pretrain_system, datamodule=datamodule)

If this isn't a good approach, could you please tell me what you'd recommend?
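`embed_data_using_backbone` runs the backbone under `torch.no_grad()`. That disables autograd but does not change module modes: a backbone left in train mode still updates its BatchNorm running statistics (and applies dropout) while embedding. The toy example below, with a hypothetical `embed` helper and a two-layer stand-in backbone, shows the difference and restores the caller's mode afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for a real backbone (hypothetical).
backbone = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
x = torch.randn(32, 4) + 3.0


def embed(model: nn.Module, data: torch.Tensor, use_eval_mode: bool) -> torch.Tensor:
    was_training = model.training
    if use_eval_mode:
        model.eval()  # BatchNorm uses, and does not update, its running statistics
    try:
        with torch.no_grad():  # disables autograd only; module modes are unaffected
            return model(data)
    finally:
        model.train(was_training)  # restore whatever mode the caller had


bn = backbone[1]
before = bn.running_mean.clone()
embed(backbone, x, use_eval_mode=False)
# True: no_grad alone did not stop BatchNorm from updating its running mean.
changed_in_train_mode = not torch.equal(before, bn.running_mean)
print(changed_in_train_mode)
```

With `use_eval_mode=True` the running statistics stay untouched, so embedding for evaluation does not perturb the pretraining run.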

guarin commented 10 months ago

I haven't found a nice solution for PyTorch Lightning so far; in the end, we decided to run linear evaluation only at the end of training. As you noticed, the issue is that you have to create a new trainer instance inside the model code, and I'm not sure this works well with PyTorch Lightning.
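A likely reason the nested fit in the callback above cannot backpropagate: by default, recent Lightning versions run the validation loop, including hooks such as `on_validation_epoch_end`, under `torch.inference_mode()` (the Trainer's `inference_mode` flag switches this to `torch.no_grad()`). Inside that context no autograd graph is recorded, and tensors created there are inference tensors that cannot later be saved for backward. The sketch below uses a hypothetical `fit_linear_probe` helper and a plain PyTorch loop instead of a nested Trainer; it temporarily leaves inference mode, clones the embeddings into normal tensors, and creates the classifier inside that block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fit_linear_probe(features: torch.Tensor, labels: torch.Tensor,
                     num_classes: int, epochs: int = 100, lr: float = 0.5) -> nn.Linear:
    """Train a linear classifier on frozen embeddings, even when called from
    inside torch.inference_mode() (e.g. from a validation hook)."""
    # inference_mode(False) leaves inference mode and re-enables grad mode;
    # enable_grad() also covers callers that only used torch.no_grad().
    with torch.inference_mode(False), torch.enable_grad():
        # Inference tensors cannot be saved for backward; clones made here are normal tensors.
        x, y = features.clone(), labels.clone()
        # Created here so its parameters are normal, trainable tensors.
        clf = nn.Linear(x.shape[1], num_classes)
        opt = torch.optim.SGD(clf.parameters(), lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            loss = F.cross_entropy(clf(x), y)
            loss.backward()
            opt.step()
    return clf


torch.manual_seed(0)
with torch.inference_mode():  # stand-in for Lightning's validation context
    feats = torch.cat([torch.randn(40, 16) + 2.0, torch.randn(40, 16) - 2.0])
    labs = torch.cat([torch.zeros(40, dtype=torch.long), torch.ones(40, dtype=torch.long)])
    probe = fit_linear_probe(feats, labs, num_classes=2)
    accuracy = (probe(feats).argmax(dim=1) == labs).float().mean().item()
print(f"probe accuracy: {accuracy:.2f}")
```

Using `feats` directly in a trainable layer outside inference mode (without the clone) raises `RuntimeError: Inference tensors cannot be saved for backward`. Whether a nested `pl.Trainer` behaves well inside a running fit is a separate question; a plain loop like this sidesteps it.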

One option is online linear evaluation: add a classification layer to the SSL module and train it during pretraining. We have a module for this here: https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking/online_linear_classifier.py#L10, and here is an example of how to use it: https://github.com/lightly-ai/lightly/blob/a5ef7d07a8233466c407307f010e7531c85d99b0/benchmarks/imagenet/resnet50/simclr.py#L31
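The idea behind online linear evaluation, sketched in plain PyTorch under stated assumptions: a linear head sits on the backbone features and is trained with a cross-entropy loss in the same step as the SSL loss, but on `features.detach()`, so the probe never changes the representation. The model, layer sizes, and the cosine-similarity `ssl_loss` are hypothetical placeholders (that placeholder loss alone would collapse and is not SimCLR); this is not the API of the lightly module linked above. Note that the online probe needs labels for the pretraining images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSLWithOnlineProbe(nn.Module):
    """Toy SSL model plus an online linear classifier on detached features (hypothetical)."""

    def __init__(self, in_dim: int, feat_dim: int, num_classes: int):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.projection_head = nn.Linear(feat_dim, feat_dim)
        self.online_classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        # detach(): the probe learns from the features but sends no gradient into the backbone.
        logits = self.online_classifier(features.detach())
        return self.projection_head(features), logits


def training_step(model, optimizer, view1, view2, labels):
    z1, logits = model(view1)
    z2, _ = model(view2)
    # Placeholder SSL objective (hypothetical): pull the two views' projections together.
    ssl_loss = -F.cosine_similarity(z1, z2, dim=1).mean()
    probe_loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    (ssl_loss + probe_loss).backward()  # one backward pass updates both parts
    optimizer.step()
    probe_acc = (logits.argmax(dim=1) == labels).float().mean().item()
    return ssl_loss.item(), probe_loss.item(), probe_acc


torch.manual_seed(0)
model = SSLWithOnlineProbe(in_dim=8, feat_dim=16, num_classes=3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
images = torch.randn(6, 8)  # stand-in for a batch of images
labels = torch.tensor([0, 1, 2, 0, 1, 2])
# Two noisy copies stand in for real augmented views.
view1 = images + 0.1 * torch.randn_like(images)
view2 = images + 0.1 * torch.randn_like(images)
print(training_step(model, optimizer, view1, view2, labels))
```

Because the probe chases a representation that is still changing, its accuracy is best read as a running signal during pretraining rather than a replacement for a final linear evaluation.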

RylanSchaeffer commented 10 months ago

> we decided to run linear evaluation only at the end of training.

This seems potentially risky to me, no? I can easily imagine spending compute on pretraining only to find out at the end that the network is useless.

Thanks for the tip about adding a classification layer and training it online, though :)

guarin commented 8 months ago

I'll close this for now, let me know if you have further questions :)