facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Obtain validation losses with SimCLR's fully unsupervised pre-training? #323

Open ycm opened 3 years ago

ycm commented 3 years ago

Question about applying fully unsupervised pre-training with VISSL

I have a dataset of images, split into two classes. My directory structure looks like:

data/
├── train/
│   ├── class0/
│   │   └── *.png
│   └── class1/
│       └── *.png
│
└── val/
    ├── class0/
    │   └── *.png
    └── class1/
        └── *.png

I'm currently running run_distributed_engines.py with a custom YAML config file, and my dataset catalog looks like this:

{"data_data": {"train": ["/data/train", "<unused>"], "val": ["/data/val", "<unused>"]}}

Inside the YAML, I set TEST_MODEL: False.

But I've noticed that the val folder (/data/val) doesn't actually get used. On Tensorboard, I do not see validation loss, only training loss.

So I have two questions:

  1. After pre-training, how do I retrieve the pretrained SimCLR weights for another downstream task? I think they are stored in checkpoint['classy_state_dict']['base_model']['model']['trunk'], but I want to make sure this is the case.
  2. Is it possible to use a validation set during pre-training? I'm confused as to how one can apply checkpoint selection or early stopping without a validation set.

I tried following the Getting Started guide and the related Colab tutorial, but those two sources seem a little too terse for my purposes.

Any help would be appreciated, thank you.

QuentinDuval commented 3 years ago

Hi @ycm,

You should indeed be able to access the SimCLR weights in checkpoint['classy_state_dict']['base_model']['model']['trunk']. This is the best option if you want to transfer to a downstream task not supported in VISSL.
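
For reference, here is a minimal sketch of extracting those trunk weights in plain PyTorch. The checkpoint path and output filename are placeholders; the state-dict keys are the ones confirmed above:

    import torch

    # Placeholder path: point this at a checkpoint written by run_distributed_engines.py
    checkpoint_path = "checkpoints/model_final_checkpoint_phase999.torch"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # The SimCLR backbone (trunk) weights live under this key, as noted above
    trunk_state_dict = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]

    # For example, save just the trunk so it can be loaded outside VISSL
    torch.save(trunk_state_dict, "simclr_trunk_weights.torch")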

If you want to try it on one of the VISSL benchmarks (for instance, a classification benchmark on any dataset), you can avoid that entirely and just pass the following option to the configuration: config.MODEL.WEIGHTS_INIT.PARAMS_FILE=/path/to/simclr/model_final_checkpoint_phase999.torch

Here is an example for Imagenet linear classification:

python tools/run_distributed_engines.py \
      config=benchmark/linear_image_classification/imagenet1k/eval_resnet_8gpu_transfer_in1k_linear \
      config.MODEL.WEIGHTS_INIT.PARAMS_FILE=<my_weights.torch>

Regarding the validation loss during SimCLR pre-training: I had never actually tried it before, so I gave it a try, but without success.

I enabled some options in the default configuration file for SimCLR (configs/config/pretrain/simclr/simclr_8node_resnet.yaml; is that the one you are using?) and modified the configuration to add a test set:

TEST:
      DATA_SOURCES: [disk_folder]
      DATASET_NAMES: [imagenette_160_folder]
      BATCHSIZE_PER_REPLICA: 128
      LABEL_TYPE: sample_index    # just an implementation detail. Label isn't used
      TRANSFORMS:
        - name: ImgReplicatePil
          num_times: 2
        - name: RandomResizedCrop
          size: 128
        - name: RandomHorizontalFlip
          p: 0.5
        - name: ImgPilColorDistortion
          strength: 1.0
        - name: ImgPilGaussianBlur
          p: 0.5
          radius_min: 0.1
          radius_max: 2.0
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      COLLATE_FUNCTION: simclr_collator
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenette_160/
      DROP_LAST: True

I also enabled the test set with config.TEST_MODEL=true.

Unfortunately, although this makes SimCLR run on the validation set, it does not output any losses (I think we are supposed to use config.METERS, but in that case the meter should be the loss itself, and I am not quite sure how to do that).

@prigoyal may have some insights I am missing, but I think we need a small fix to enable this.

@prigoyal what do you think?

prigoyal commented 3 years ago

thank you @ycm for reaching out. For the validation set during pre-training, is your use case to monitor the loss on the validation set, or to run some benchmark tasks (like linear evaluation) on model checkpoints after every epoch?

The latter is being worked on by @iseessel.

For the losses on the validation set, we will need to relax this condition https://github.com/facebookresearch/vissl/blob/master/vissl/hooks/log_hooks.py#L208 and possibly just extend it so that it works for both train and test.

jirvin16 commented 3 years ago

Thank you @QuentinDuval and @prigoyal ! I'm working with @ycm on this project, and have a follow-up question - how is the final checkpoint selected during pre-training? Is early stopping being used somehow, or is the model trained for N epochs and the last checkpoint used?

iseessel commented 3 years ago

Hi @jirvin16, we do not have early-stopping -- you are welcome to come up with a proposal to add it, if interested!

Correct, we train for OPTIMIZER.NUM_EPOCHS and use the last checkpoint.

jirvin16 commented 3 years ago

Got it, thanks. Are the validation set losses logged somehow? The reason I ask is that we're hoping to get a sense of whether the model is learning something without having to use a downstream/benchmark task. Perhaps monitoring validation loss isn't necessary for this; is there a way to sanity-check the training loss, e.g. by looking at curves/values for other tasks or using some heuristic for reasonable performance? Any advice/pointers on how to do this would be really appreciated.

iseessel commented 3 years ago

Sure, of course! For monitoring training losses, you have a few options:

  1. Enable tensorboard. See here
  2. You can also parse the "log.txt" that is generated and use something like this; a rough parsing sketch follows below.
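
A rough sketch of option 2, assuming the log lines contain entries of the form loss: <value>; the pattern and path below are assumptions to adapt to whatever your generated log.txt actually looks like:

    import re

    # Assumed format: lines containing something like "loss: 6.8432"
    loss_pattern = re.compile(r"loss:\s*([0-9]+\.?[0-9]*)")

    losses = []
    with open("checkpoints/log.txt") as f:  # placeholder path
        for line in f:
            match = loss_pattern.search(line)
            if match:
                losses.append(float(match.group(1)))

    print(f"parsed {len(losses)} loss values; last value: {losses[-1] if losses else 'n/a'}")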

In terms of monitoring performance of downstream tasks, I definitely think it's a good idea to start looking at some transfer tasks sooner rather than later if you are running an expensive training. Some options:

  1. You can manually launch a benchmark task using your SSL pre-training checkpoints as they become available. We have a wide array of benchmark configs; I would recommend starting with something like a linear transfer on ImageNet. See the docs on this, and there is a relevant tutorial here. You would instead load your own weights from your checkpoint.

  2. If you are using slurm, you can use the "Benchmark Suite Evaluator". See here, and here to get started. This will launch a slurm job that automatically evaluates checkpoints as they become available.

Hope this is helpful.

jirvin16 commented 3 years ago

Thanks for your response! We've figured out how to monitor the training losses with TensorBoard. I was actually asking how to sanity-check model training/performance without monitoring the downstream tasks (we're training on our own custom dataset). For example, here is our training loss curve: [image: training loss curve]

Are there any heuristics, comparisons to curves from other datasets, etc. that we can use to help better understand this curve? We're using a SimCLR model with NCE loss + temperature of 0.1.

iseessel commented 3 years ago

I don't have any SimCLR loss curves off-hand. You could launch a SimCLR training on ImageNet to compare loss curves.

Some googling also yielded some loss curves, e.g. here or here. However, the distributions of your dataset and ImageNet are likely very different, which could cause notable differences in the loss curves.

Looking at the loss curve two things jump out at me:

  1. The loss is really low, which suggests to me that some sort of collapse might be happening (a rough baseline calculation follows after this list). Are you using batch norm? Are you using LARC? What is your training setup (num_gpus, num_nodes)? What is your batch size?

  2. There are sudden increases in loss values. To rule out software issues, can you check whether your training is frequently failing and resuming from checkpoints? This could explain the sudden jumps in loss values.
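
As a rough baseline for judging how low is "too low": with completely random, untrained embeddings, the SimCLR (NT-Xent) loss sits near ln(2N - 1) for a batch of N images, since each of the 2N augmented views is classified against 2N - 1 candidates. A loss far below that from very early in training supports the collapse suspicion. This is back-of-the-envelope math rather than anything VISSL computes; the sketch below just estimates that baseline with random embeddings, using the temperature of 0.1 mentioned above:

    import math

    import torch
    import torch.nn.functional as F

    def random_embedding_nt_xent(batch_size: int, dim: int = 128, temperature: float = 0.1) -> float:
        # Monte-Carlo estimate of the NT-Xent loss for random (untrained) embeddings.
        # Two views per image, arranged so the views of image k sit at rows 2k and 2k + 1.
        z = F.normalize(torch.randn(2 * batch_size, dim), dim=1)
        sim = z @ z.t() / temperature
        sim.fill_diagonal_(float("-inf"))  # a view is never compared against itself
        idx = torch.arange(2 * batch_size)
        positives = torch.where(idx % 2 == 0, idx + 1, idx - 1)  # the other view of the same image
        return F.cross_entropy(sim, positives).item()

    for n in (64, 256, 4096):
        print(f"batch {n}: random-embedding NT-Xent ~ {random_embedding_nt_xent(n):.2f}, "
              f"ln(2N - 1) = {math.log(2 * n - 1):.2f}")

For a global batch size of 64, that baseline is about ln(127) ≈ 4.8, which gives a sense of scale when reading the training curve.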

However, I still think the most reliable way to collect signal on your model's performance is to set up a quick transfer learning task if at all possible.

CC: @prigoyal If you have anything to add.

jirvin16 commented 3 years ago

This is very helpful, thanks so much.

  1. Great observation re: low loss. I believe we should be using batch norm and LARC. We're running on 1 node and 1 GPU with a batch size of 256. Here is the config: custom_simclr_config.txt. Does anything fishy pop out there?
  2. Will look into this and let you know.

Thank you again for all of your help!

iseessel commented 3 years ago

Happy to help Jeremy!

Quick question: are you building from source, or are you using the pip version of the package?

From the config, it looks to me like you are using a global batch size of 64 (1 node, 1 GPU, 64 per replica)?
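
For clarity on where that 64 comes from: the effective (global) batch size is the batch size per replica multiplied by the total number of GPUs across all nodes. With the setup described above:

    # Global batch size = batch size per replica x total number of GPUs
    num_nodes = 1
    gpus_per_node = 1
    batchsize_per_replica = 64

    global_batch_size = num_nodes * gpus_per_node * batchsize_per_replica
    print(global_batch_size)  # 64 images per SGD step in total

(The base_lr_batch_size value that appears in the config is, as the name suggests, the reference batch size used for learning-rate scaling rather than the actual training batch size.)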

SimCLR struggles with smaller batch sizes because of the contrastive loss (there are fewer examples to contrast against), so I would try a larger batch size if possible. Our reference config, simclr_8node_resnet, uses a batch size of 4096. If you do scale up across GPUs, you may also want to consider using Global/Sync Batch Norm (see SYNC_BN_CONFIG in defaults.yaml); the original SimCLR paper talks about this extensively. This might also be helpful: "using a square root learning rate scaling can improve performance of ones with small batch sizes"
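
On the square-root learning-rate scaling mentioned in that quote: the two common conventions are linear scaling (multiply the base LR by batch_size / reference_batch_size) and square-root scaling (multiply by the square root of that ratio), with the latter often behaving better at small batch sizes. A small illustrative sketch follows; the base LR of 0.3 and the reference batch size of 256 follow the SimCLR paper's convention and are not values taken from your config:

    import math

    def scaled_lr(base_lr, batch_size, reference_batch_size=256, rule="linear"):
        # Scale a base LR (defined at reference_batch_size) to a new batch size.
        ratio = batch_size / reference_batch_size
        if rule == "linear":
            return base_lr * ratio
        if rule == "sqrt":
            return base_lr * math.sqrt(ratio)
        raise ValueError(f"unknown scaling rule: {rule}")

    base_lr = 0.3  # assumed value; substitute the base LR from your own config
    for bs in (64, 256, 4096):
        print(bs, round(scaled_lr(base_lr, bs, rule="linear"), 4), round(scaled_lr(base_lr, bs, rule="sqrt"), 4))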

If you are willing to change approaches, I believe SwAV and MoCo are more amenable to smaller batch sizes, and we support both in VISSL. See the moco docs and swav docs. When adapting the configs to a smaller batch size, I would recommend reading the papers to see if there are other hyperparameters you need to adapt (we do take care of things like the LR).

jirvin16 commented 3 years ago

We're using the pip version of the package.

I must be looking at the wrong value for the batch size. It looks like BATCHSIZE_PER_REPLICA is 64 and base_lr_batch_size is 256. We'll try increasing BATCHSIZE_PER_REPLICA as much as we can with our hardware.

I believe we should already be using Global/Sync Batch Norm based on the config, as long as we increase the number of GPUs (we're aiming to do this on a single node for now). Is that correct?

Thanks for the suggestions re: other approaches. We'll explore these as well.

iseessel commented 3 years ago

Yep, correct, no change required: it's actually already enabled (it's just a no-op when using 1 GPU).

Pedrexus commented 3 years ago

> thank you @ycm for reaching out. For the validation set during pre-training, is your use case to monitor the loss on the validation set, or to run some benchmark tasks (like linear evaluation) on model checkpoints after every epoch?
>
> The latter is being worked on by @iseessel.
>
> For the losses on the validation set, we will need to relax this condition https://github.com/facebookresearch/vissl/blob/master/vissl/hooks/log_hooks.py#L208 and possibly just extend it so that it works for both train and test.

Hello, folks. Is there an update on logging validation losses?