Open 3bao opened 6 days ago
Hello, I'm really glad that you're interested in our work and thank you for this great question.
You are absolutely right: in the code as I put it on GitHub, the "validation_step" is not used during pre-training, and the checkpoints are saved based on the training loss.
The "validation_step" was originally implemented in PyTorch Lightning Bolts. At the beginning of my research in this area, I used this "validation_step" and held out part of the pre-training dataset for a pre-training validation. However, in all my experiments, this validation loss was useless for predicting how well the pre-training would end up performing on the downstream tasks, or for deciding when to stop the pre-training. It gave me no information beyond what I already got from the pre-training loss.
So to figure out which checkpoint works well, I needed to run the downstream tasks themselves. I tried two things (swav_module_lidc.py, line 602):
1)

```python
import os
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Save the model
lr_monitor = LearningRateMonitor(logging_interval="step")
model_checkpoint = ModelCheckpoint(filename=os.path.join(checkpoint_dir, "{epoch}-{train_loss:.2f}"),
                                   save_last=True, save_top_k=200,
                                   monitor="train_loss")
```
Here I save the best 200 checkpoints based on the training loss.
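Conceptually, `save_top_k` with `monitor="train_loss"` keeps only the k checkpoints with the lowest monitored value and discards the rest as training progresses. A minimal pure-Python sketch of that selection logic (the function name and data are illustrative, not Lightning internals):

```python
import heapq

def top_k_checkpoints(epoch_losses, k):
    """Return the k (epoch, train_loss) pairs with the lowest loss,
    mimicking what ModelCheckpoint(save_top_k=k, monitor="train_loss")
    ends up retaining on disk."""
    return sorted(heapq.nsmallest(k, epoch_losses, key=lambda p: p[1]))

# Example: train_loss logged at the end of each epoch
logged = [(0, 1.90), (1, 1.40), (2, 1.55), (3, 1.10), (4, 1.20)]
print(top_k_checkpoints(logged, k=3))
# -> [(1, 1.4), (3, 1.1), (4, 1.2)]
```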
2)

```python
# Save the model
lr_monitor = LearningRateMonitor(logging_interval="step")
model_checkpoint = ModelCheckpoint(filename=os.path.join(checkpoint_dir, "{epoch}-{train_loss:.2f}"),
                                   save_last=True,
                                   every_n_epochs=50,
                                   monitor="train_loss")
```
Here I save a checkpoint every 50 epochs.
So the training loss is also not really a big help in figuring out which checkpoint works best, but it at least gave a little information. For my publication, I saved a checkpoint every 50 epochs, and starting from epoch 300 I ran the downstream tasks on each of the saved checkpoints. This way I found which checkpoint (i.e., at which pre-training step) really works best on the downstream tasks. I found that this was not at all correlated with a low validation loss, so I stopped using the "validation_step" part and used all the data for pre-training without holding anything out for a pre-training validation. I also found that the correlation with the training loss is very low. So far, I have not found a good way to decide when to stop the pre-training; it seems that only downstream task runs really show how well the pre-training works.
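The checkpoint sweep described above can be sketched as follows. The filename pattern assumes Lightning's default metric-name insertion for the `"{epoch}-{train_loss:.2f}"` template used in the snippets above; `evaluate_downstream` is a hypothetical placeholder for whatever downstream fine-tuning/evaluation pipeline you run:

```python
import re

# Matches filenames produced by the "{epoch}-{train_loss:.2f}" template,
# e.g. "epoch=300-train_loss=1.40.ckpt" (assumes Lightning's default
# auto-inserted metric names).
CKPT_RE = re.compile(r"epoch=(\d+)-train_loss=(\d+\.\d+)\.ckpt$")

def checkpoints_from(filenames, start_epoch=300):
    """Return (epoch, filename) pairs for checkpoints at or after start_epoch,
    sorted by epoch."""
    picked = []
    for name in filenames:
        m = CKPT_RE.search(name)
        if m and int(m.group(1)) >= start_epoch:
            picked.append((int(m.group(1)), name))
    return sorted(picked)

# Each selected checkpoint would then be scored on the downstream task:
# for epoch, path in checkpoints_from(ckpt_files):
#     score = evaluate_downstream(path)  # hypothetical helper
```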
However, all of this was only true on my datasets, so it might be different on yours. I would be really interested in your experience.
I hope this explanation helps you a little bit. This is the original code from PyTorch Lightning Bolts that I modified: https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py You can of course also experiment with the original code and see how it works for your data. They also have nice documentation here: https://lightning-bolts.readthedocs.io/en/latest/.
Let me know if this answers your question and if you have further questions.
Hello, thanks for sharing your paper and code, this is great work! I want to run SwAV pre-training using my own data, and I am following your code swav_module_lidc.py. To my understanding, there is no validation dataset/dataloader, and the checkpoints will be saved based on the training loss. Then what does "validation_step" do during training? Thank you very much!