Wolfda95 / SSL-MedicalImagining-CL-MAE

Self-Supervised Pre-Training with Contrastive and Masked Autoencoder Methods for Dealing with Small Datasets in Deep Learning for Medical Imaging

Pretrain Swav #2

Open 3bao opened 6 days ago

3bao commented 6 days ago

Hello, thanks for sharing your paper and code; this is great work! I want to run SwAV pre-training on my own data, and I am following your code in swav_module_lidc.py. To my understanding, there is no validation dataset/dataloader, and the checkpoints are saved based on the training loss. What, then, does "validation_step" do during training? Thank you very much!

Wolfda95 commented 1 day ago

Hello, I'm really glad that you're interested in our work and thank you for this great question.

You are absolutely right: in the code as I put it on GitHub, I do not use the "validation_step" during pre-training, and the checkpoints are saved based on the training loss.

The "validation_step" was originally implemented by PyTorch Lightning Bolts. At the beginning of my research in this area, I used this "validation_step" and took away parts of the pre-training dataset for a pre-training validation. However, in all my experiments, this validation loss was absolutely useless for finding out how good the pre-training would end up being on the downstream tasks, and when to stop the pre-training. So it gave me no additional information above the information I got from the pre-training loss.

So, to figure out which checkpoint works well, I had to run downstream tasks. I tried two things (swav_module_lidc.py, line 602):

1)

    # Save the model
    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(
        filename=os.path.join(checkpoint_dir, "{epoch}-{train_loss:.2f}"),
        save_last=True,
        save_top_k=200,          # keep the 200 best checkpoints
        monitor="train_loss",
    )

Here I save the best 200 checkpoints based on the training loss.

2)

    # Save the model
    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(
        filename=os.path.join(checkpoint_dir, "{epoch}-{train_loss:.2f}"),
        save_last=True,
        every_n_epochs=50,       # only write a checkpoint every 50 epochs
        monitor="train_loss",    # note: with the default save_top_k=1, only the best
                                 # of these periodic checkpoints is kept;
                                 # use save_top_k=-1 to keep every one
    )

Here I save a checkpoint every 50 epochs.

So the training loss is also not a big help in figuring out which checkpoint works best, but it at least gave a little information. For my publication, I saved a checkpoint every 50 epochs, and starting from epoch 300 I ran downstream tasks on each of the saved checkpoints. This way I found which checkpoint (i.e. which pre-training step) really works best on the downstream tasks.

I found that this was not at all correlated with a low validation loss, so I stopped using the "validation_step" part and just used all the data for pre-training, without holding anything out for a pre-training validation. I also found that the correlation with the training loss is very low. So far I have not found a good way to decide when to stop the pre-training; it seems that only downstream task runs really show how well the pre-training works.
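For illustration, the checkpoint sweep described above could look roughly like the sketch below. This is not code from the repository; the checkpoint directory and `run_downstream_task` are placeholders for your own paths and your own fine-tuning/evaluation routine.

    # Minimal sketch of the checkpoint sweep: load each saved checkpoint and
    # hand its weights to a downstream evaluation. "run_downstream_task" is a
    # placeholder for your own fine-tuning script, not a function of this repo.
    import glob
    import torch

    results = {}
    for ckpt_path in sorted(glob.glob("checkpoints/epoch=*.ckpt")):
        # Lightning checkpoints store the model weights under "state_dict"
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        results[ckpt_path] = run_downstream_task(state_dict)  # e.g. downstream metric

    best_ckpt = max(results, key=results.get)
    print("Best pre-training checkpoint on the downstream task:", best_ckpt)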

However, all of this was only true for my datasets, so it might be different for yours. I would be really interested in your experience.

I hope this explanation helps you a little bit. My code is a modified version of the original PyTorch Lightning Bolts implementation: https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py You can of course also experiment with the original code and see how it works for your data. They also have nice documentation here: https://lightning-bolts.readthedocs.io/en/latest/.

Let me know if this answers your question, and feel free to ask if anything else is unclear.