facebookresearch / suncet

Code to reproduce the results in the FAIR research papers "Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples" https://arxiv.org/abs/2104.13963 and "Supervision Accelerates Pre-training in Contrastive Semi-Supervised Learning of Visual Representations" https://arxiv.org/abs/2006.10803
MIT License

What "classes_per_batch" should I set to reproduce the ImageNet 10% label performance? #11

Closed by Ir1d 3 years ago

Ir1d commented 3 years ago

Hi, thank you for sharing the code.

I'm a bit confused by the classes_per_batch option in the config.

Section 5 of the paper says "randomly sample 6720 images, 960 classes and 7 images per class", but the config uses 15 classes and 7 images per class. Do I need to change this 15?

In the code, num_support = supervised_views * s_batch_size * classes_per_batch, so the num_support we actually have on each GPU is 105 = 1 * 7 * 15.

I see that 960 = 15 x 64. Does that mean the actual num_support should be classes_per_batch * u_batch_size?

I ran main.py on a single machine with 4 GPUs, but the trained model's accuracy is extremely low (less than 10%). I'm not sure whether that is related to this problem.

MidoAssran commented 3 years ago

Hi @Ir1d ,

The config is the default config you would use on 64 GPUs. Setting classes_per_batch: 15 and unique_classes_per_rank: true on 64 GPUs gives you the default 15 * 64 = 960 classes. In line 308 of the code, num_support is the total number of supervised images loaded on that GPU (each GPU only computes the loss on its local samples). The total number of support samples across all GPUs is therefore num_support * world_size = 105 * 64 = 6720. Hope that clarifies classes_per_batch, but if not, let me know!
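The per-GPU and global support-set arithmetic above can be checked with a few lines of Python. This is a sketch, not code from the repository: the helper name `support_set_sizes` is hypothetical, and it assumes `unique_classes_per_rank: true`, so the classes sampled on each rank are disjoint and simply add up across GPUs.

```python
def support_set_sizes(supervised_views: int,
                      imgs_per_class: int,
                      classes_per_batch: int,
                      world_size: int) -> dict:
    """Per-GPU and global support-set sizes (hypothetical helper).

    Assumes unique_classes_per_rank: true, so every rank samples a
    disjoint set of classes and the class counts add up across GPUs.
    """
    # Same formula as quoted in the issue:
    # num_support = supervised_views * s_batch_size * classes_per_batch
    num_support = supervised_views * imgs_per_class * classes_per_batch
    return {
        "num_support_per_gpu": num_support,
        "total_classes": classes_per_batch * world_size,
        "total_support": num_support * world_size,
    }


# Default config on 64 GPUs: 15 classes x 7 images per class on each GPU.
sizes = support_set_sizes(supervised_views=1, imgs_per_class=7,
                          classes_per_batch=15, world_size=64)
print(sizes)
# {'num_support_per_gpu': 105, 'total_classes': 960, 'total_support': 6720}
```

With 64 GPUs this recovers the 960 classes and 6720 support images from Section 5 of the paper, while each GPU only holds 105 of them.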

As for training on 4 GPUs, I haven't tried that, but I'll look into establishing a small-batch baseline. In general, you need to do a couple of things:

  1. decrease the learning rate to account for the smaller batch size (e.g., linear scaling or square-root scaling)
  2. you may need longer training to compensate for the smaller support set
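A minimal sketch of the two learning-rate scaling rules mentioned in the list. The helper `scale_lr` is hypothetical, and so is the choice of reference point: it treats `lr: 1.2` at a global unlabeled batch of 8 × 32 = 256 (the 8-GPU config posted later in this thread) as the baseline, and assumes the global unlabeled batch (`unsupervised_batch_size` × number of GPUs) is the batch size to scale by.

```python
import math


def scale_lr(ref_lr: float, ref_batch: int, new_batch: int,
             rule: str = "linear") -> float:
    """Rescale a learning rate tuned at ref_batch for a run at new_batch."""
    ratio = new_batch / ref_batch
    if rule == "linear":
        # Linear scaling: lr proportional to batch size.
        return ref_lr * ratio
    if rule == "sqrt":
        # Square-root scaling: lr proportional to sqrt(batch size).
        return ref_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")


# Hypothetical example: reference of lr 1.2 at 8 GPUs x 32 = 256 unlabeled
# images, moving to 4 GPUs x 16 = 64 unlabeled images.
ref_lr, ref_batch = 1.2, 8 * 32
new_batch = 4 * 16
print(scale_lr(ref_lr, ref_batch, new_batch, "linear"))  # 0.3
print(scale_lr(ref_lr, ref_batch, new_batch, "sqrt"))    # 0.6
```

Square-root scaling shrinks the learning rate less aggressively than linear scaling; either value is only a starting point for tuning.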

As for the support set, we found (see the ablation in Section 7) that decreasing supervised_imgs_per_class should not make a big difference; it is more important to use a larger classes_per_batch. Even with around 400 classes in the support set (the smallest we tried was 448), performance is still good after 100 epochs.

On 4 GPUs, you can try setting classes_per_batch: 120, supervised_imgs_per_class: 2, and unsupervised_batch_size: 16. You would obviously need to tune the learning rate, but that should fit on something like a P100 with roughly 16 GB of memory if you set use_fp16: true.
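To see how the suggested 4-GPU setup compares with the smallest support set used in the ablation, the sizes can be tallied as below. This is an illustrative sketch with a hypothetical helper `describe_setup`; it assumes `unique_classes_per_rank: true` and one supervised view, and the 448-class threshold is the figure quoted above, not a hard limit.

```python
# Smallest support set (in classes) mentioned in the ablation discussion above.
MIN_CLASSES_TRIED = 448


def describe_setup(world_size: int, classes_per_batch: int,
                   imgs_per_class: int, unlabeled_per_gpu: int,
                   supervised_views: int = 1) -> dict:
    """Global support-set and unlabeled-batch sizes for one training step.

    Hypothetical helper; assumes unique_classes_per_rank: true
    (disjoint classes on each GPU).
    """
    total_classes = classes_per_batch * world_size
    return {
        "total_classes": total_classes,
        "total_support": supervised_views * imgs_per_class * total_classes,
        "global_unlabeled_batch": unlabeled_per_gpu * world_size,
    }


# Suggested 4-GPU setting: 120 classes x 2 images and 16 unlabeled images per GPU.
small = describe_setup(world_size=4, classes_per_batch=120,
                       imgs_per_class=2, unlabeled_per_gpu=16)
print(small)
# {'total_classes': 480, 'total_support': 960, 'global_unlabeled_batch': 64}
print(small["total_classes"] >= MIN_CLASSES_TRIED)  # True
```

The trade-off is visible in the numbers: raising classes_per_batch to 120 keeps the class count above 448, while cutting images per class and the unlabeled batch keeps per-GPU memory down.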

Ir1d commented 3 years ago

Hi @MidoAssran, could you share the training time per epoch on, for example, 8 or 4 GPUs (anything fewer than 64 GPUs)?

MidoAssran commented 3 years ago

Hi @Ir1d, I don't have this on hand since I haven't tried something like 8 GPUs yet, but I'll definitely look into it.

Ir1d commented 3 years ago

@MidoAssran Thank you so much. I'm seeing about 2 h per epoch, and I'm not sure whether I ran it correctly.

MidoAssran commented 3 years ago

Hi @Ir1d,

Apologies for the delay in getting back to you; I've had a lot on my plate recently, but I finally tried a small-batch run in the 10% label setting on ImageNet. See below (mostly reproduced from an earlier comment of mine). I've included my config file as well, so let me know if you have any trouble reproducing this. I'm going to close this issue for now.

Using 8 V100 GPUs for 100 epochs with 10% of ImageNet labels, I get

This top-1 accuracy is consistent with the ablation in the bottom row of Table 4 in the paper (similar support set, but much larger batch size).

Here is the config I used to produce this result when running on 8 GPUs. To explain some of the choices:

All other hyper-parameters are identical to the large-batch setup.

```yaml
criterion:
  classes_per_batch: 70
  me_max: false
  sharpen: 0.25
  supervised_imgs_per_class: 3
  supervised_views: 1
  temperature: 0.1
  unsupervised_batch_size: 32
data:
  color_jitter_strength: 1.0
  data_seed: null
  dataset: imagenet
  image_folder: imagenet_full_size/061417/
  label_smoothing: 0.1
  multicrop: 6
  normalize: true
  root_path: datasets/
  subset_path: imagenet_subsets
  unique_classes_per_rank: true
  unlabeled_frac: 0.90
logging:
  folder: /path_to_save_models_and_logs/
  write_tag: paws
meta:
  copy_data: true
  device: cuda:0
  load_checkpoint: false
  model_name: resnet50
  output_dim: 2048
  read_checkpoint: null
  use_fp16: true
  use_pred_head: true
optimization:
  epochs: 100
  final_lr: 0.0012
  lr: 1.2
  momentum: 0.9
  nesterov: false
  start_lr: 0.3
  warmup: 10
  weight_decay: 1.0e-06
```
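For reference, the quantities implied by this config on 8 GPUs can be derived as follows. The learning-rate function is a sketch of the warmup-then-cosine shape that the `start_lr`, `lr`, `final_lr` and `warmup` keys suggest (linear ramp from `start_lr` to `lr` over the first `warmup` epochs, then cosine decay to `final_lr`). `lr_at` is a hypothetical name, and the repository's actual scheduler may differ in details, for example by stepping per iteration rather than per epoch.

```python
import math

# Values copied from the 8-GPU config above.
cfg = {
    "classes_per_batch": 70,
    "supervised_imgs_per_class": 3,
    "supervised_views": 1,
    "unsupervised_batch_size": 32,
    "start_lr": 0.3,
    "lr": 1.2,
    "final_lr": 0.0012,
    "warmup": 10,
    "epochs": 100,
}
world_size = 8  # number of GPUs in the run

# Assumes unique_classes_per_rank: true (disjoint classes on each GPU).
total_classes = cfg["classes_per_batch"] * world_size
num_support = (cfg["supervised_views"] * cfg["supervised_imgs_per_class"]
               * cfg["classes_per_batch"])
global_unlabeled = cfg["unsupervised_batch_size"] * world_size
print(total_classes, num_support, num_support * world_size, global_unlabeled)
# 560 210 1680 256


def lr_at(epoch: float) -> float:
    """Assumed schedule: linear warmup start_lr -> lr, then cosine decay to final_lr."""
    if epoch < cfg["warmup"]:
        frac = epoch / cfg["warmup"]
        return cfg["start_lr"] + frac * (cfg["lr"] - cfg["start_lr"])
    progress = (epoch - cfg["warmup"]) / (cfg["epochs"] - cfg["warmup"])
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cfg["final_lr"] + (cfg["lr"] - cfg["final_lr"]) * cosine


print(f"{lr_at(0):.4f} {lr_at(10):.4f} {lr_at(100):.4f}")
# 0.3000 1.2000 0.0012
```

The 560-class support set per step is above the 448-class minimum mentioned earlier in the thread, which fits the "similar support set" comparison with Table 4.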