zhu-xlab / SSL4EO-S12

SSL4EO-S12: a large-scale dataset for self-supervised learning in Earth observation
Apache License 2.0

Reproducing the results of linear probing #14

Closed: Ehzoahis closed this issue 11 months ago.

Ehzoahis commented 11 months ago

Hi authors!

Thank you for your amazing work. However, I ran into difficulties when trying to reproduce the linear probing results of ViT-S/16 on the EuroSAT dataset. I used the provided backbone but implemented the training code myself, and the resulting accuracy is 5% lower than the reported results. I also found that the results vary considerably with the learning rate. Could you tell me which learning rate you used for the linear probing experiments? Thanks!

Ehzoahis commented 11 months ago

Actually, I tried logistic regression on features extracted with ViT-S/16, and I can only reach 93.7% accuracy when following the suggested data preprocessing pipeline.
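For readers less familiar with this baseline: logistic regression on extracted features means fitting a linear softmax classifier on frozen embeddings, with no gradient flowing into the backbone. The sketch below illustrates that idea in plain Python under the assumption that features have already been extracted as lists of floats; `train_softmax_regression`, `predict`, and `accuracy` are hypothetical helpers, not the evaluation code from the repository, and real ViT-S/16 embeddings would normally be handled with a library solver.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_softmax_regression(
    features: List[List[float]],
    labels: List[int],
    num_classes: int,
    lr: float = 0.1,
    epochs: int = 200,
    weight_decay: float = 0.0,
) -> Tuple[List[List[float]], List[float]]:
    """Fit W (num_classes x dim) and b by full-batch gradient descent on the
    mean cross-entropy loss. The features are frozen; only W and b are learned."""
    dim, n = len(features[0]), len(features)
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        grad_W = [[0.0] * dim for _ in range(num_classes)]
        grad_b = [0.0] * num_classes
        for x, y in zip(features, labels):
            logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(num_classes)]
            probs = softmax(logits)
            for k in range(num_classes):
                # d(cross-entropy)/d(logit_k) = p_k - 1[k == y]
                err = probs[k] - (1.0 if k == y else 0.0)
                grad_b[k] += err / n
                for j in range(dim):
                    grad_W[k][j] += err * x[j] / n
        for k in range(num_classes):
            b[k] -= lr * grad_b[k]
            for j in range(dim):
                # Optional L2 penalty on the weights only.
                W[k][j] -= lr * (grad_W[k][j] + weight_decay * W[k][j])
    return W, b


def predict(W: List[List[float]], b: List[float], x: List[float]) -> int:
    logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(len(W))]
    return max(range(len(W)), key=lambda k: logits[k])


def accuracy(W, b, features, labels) -> float:
    correct = sum(predict(W, b, x) == y for x, y in zip(features, labels))
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    # Toy 2-D "features" for three well-separated classes.
    train_x = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [-1.0, -1.0], [-0.9, -1.1]]
    train_y = [0, 0, 1, 1, 2, 2]
    W, b = train_softmax_regression(train_x, train_y, num_classes=3, lr=0.5, epochs=300)
    print(accuracy(W, b, train_x, train_y))
```

Because only `W` and `b` are trained, the outcome depends directly on `lr`, `epochs`, and `weight_decay`, which is one reason linear-probe numbers can shift noticeably between implementations.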

wangyi111 commented 11 months ago

Hi @Ehzoahis, thanks for your question. For reproducibility, below is our example command for MoCo ViT-S/16 on EuroSAT-MS:

```bash
# define available GPUs
export CUDA_VISIBLE_DEVICES=0,1,2,3

# run the script as a Slurm job
srun python -u linear_EU_moco_v3.py \
--data_dir /p/scratch/hai_ssl4eo/data/eurosat/tif \
--bands B13 \
--checkpoints_dir /p/project/hai_ssl4eo/wang_yi/ssl4eo-s12-dataset/src/benchmark/fullset_temp/checkpoints/moco_lc/EU_vits16 \
--backbone vit_small \
--train_frac 1.0 \
--batchsize 64 \
--lr 0.1 \
--cos \
--epochs 100 \
--num_workers 10 \
--seed 42 \
--dist_url $dist_url \
--in_size 224 \
--pretrained /p/project/hai_ssl4eo/wang_yi/ssl4eo-s12-dataset/src/benchmark/fullset_temp/checkpoints/moco/B13_vits16_224/checkpoint_0099.pth.tar \
--linear
```

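The command combines `--lr 0.1` with `--cos` and `--epochs 100`. Assuming `--cos` enables the half-cosine, per-epoch decay used in MoCo-style linear classification scripts (not verified against `linear_EU_moco_v3.py`, which may add warmup or update per iteration), the learning rate would evolve as in this sketch; `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(base_lr: float, epoch: int, total_epochs: int) -> float:
    """Half-cosine decay from base_lr at epoch 0 down to 0 at total_epochs.

    Assumption: mirrors a MoCo-style adjust_learning_rate with --cos; the
    actual script may differ.
    """
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


if __name__ == "__main__":
    # Settings from the command above: --lr 0.1 --cos --epochs 100
    for epoch in (0, 5, 10, 20, 50, 99):
        print(epoch, round(cosine_lr(0.1, epoch, 100), 5))
```

Under this schedule the rate stays close to 0.1 during the first epochs and only reaches half of it at epoch 50, so the base value passed via `--lr` dominates early training.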
There are several items that might influence your results:

You may want to look at our provided transfer learning code and check for potential implementation differences.

Below is part of our linear probing log:

```text
train_len: 21600 val_len: 5400
=> loading checkpoint '/p/project/hai_ssl4eo/wang_yi/ssl4eo-s12-dataset/src/benchmark/fullset_temp/checkpoints/moco/B13_vits16_224/checkpoint_0099.pth.tar'
=> loaded pre-trained model '/p/project/hai_ssl4eo/wang_yi/ssl4eo-s12-dataset/src/benchmark/fullset_temp/checkpoints/moco/B13_vits16_224/checkpoint_0099.pth.tar'
Start training...
...
[5,    20] loss: 0.164 acc: 95.547 batch_time: 0.178 data_time: 0.115 train_time: 0.024 score_time: 0.038
[5,    40] loss: 0.175 acc: 94.844 batch_time: 0.099 data_time: 0.031 train_time: 0.028 score_time: 0.040
[5,    60] loss: 0.155 acc: 95.312 batch_time: 0.111 data_time: 0.025 train_time: 0.031 score_time: 0.054
[5,    80] loss: 0.160 acc: 95.625 batch_time: 0.088 data_time: 0.018 train_time: 0.027 score_time: 0.043
Epoch 5 val_loss: 0.135 val_acc: 95.926 time: 27.68344759941101 seconds.
...
[10,    20] loss: 0.142 acc: 96.094 batch_time: 0.175 data_time: 0.091 train_time: 0.023 score_time: 0.061
[10,    40] loss: 0.141 acc: 96.172 batch_time: 0.110 data_time: 0.034 train_time: 0.026 score_time: 0.049
[10,    60] loss: 0.150 acc: 95.625 batch_time: 0.108 data_time: 0.052 train_time: 0.022 score_time: 0.035
[10,    80] loss: 0.150 acc: 95.234 batch_time: 0.089 data_time: 0.043 train_time: 0.022 score_time: 0.023
Epoch 10 val_loss: 0.106 val_acc: 96.670 time: 19.81748080253601 seconds.
...
[15,    20] loss: 0.117 acc: 96.875 batch_time: 0.178 data_time: 0.123 train_time: 0.027 score_time: 0.027
[15,    40] loss: 0.121 acc: 96.016 batch_time: 0.104 data_time: 0.052 train_time: 0.020 score_time: 0.032
[15,    60] loss: 0.143 acc: 95.078 batch_time: 0.112 data_time: 0.058 train_time: 0.025 score_time: 0.029
[15,    80] loss: 0.144 acc: 95.938 batch_time: 0.081 data_time: 0.036 train_time: 0.022 score_time: 0.023
Epoch 15 val_loss: 0.097 val_acc: 96.931 time: 19.70111346244812 seconds.
...
[20,    20] loss: 0.119 acc: 96.719 batch_time: 0.186 data_time: 0.133 train_time: 0.025 score_time: 0.028
[20,    40] loss: 0.113 acc: 96.562 batch_time: 0.143 data_time: 0.084 train_time: 0.027 score_time: 0.031
[20,    60] loss: 0.112 acc: 96.250 batch_time: 0.089 data_time: 0.009 train_time: 0.026 score_time: 0.055
[20,    80] loss: 0.109 acc: 96.953 batch_time: 0.088 data_time: 0.006 train_time: 0.022 score_time: 0.060
Epoch 20 val_loss: 0.088 val_acc: 97.414 time: 20.258484363555908 seconds.
```

Ehzoahis commented 11 months ago

Thanks! I would also like to ask how you handled the mismatched bands between the pre-trained models and downstream datasets such as BigEarthNet and So2Sat.

wangyi111 commented 11 months ago

We simply padded the downstream data with all-zero channels. For example, we added one all-zero B10 band to BigEarthNet: https://github.com/zhu-xlab/SSL4EO-S12/blob/bc0454331f627c248b505423b495adc12558838d/src/benchmark/transfer_classification/linear_BE_moco.py#L334-L335
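The zero-padding idea can be sketched for a single patch stored band by band. The sketch assumes the 13-band Sentinel-2 L1C order B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B10, B11, B12 for the pre-trained backbone and a BigEarthNet patch that has every band except B10; check the band order in the linked code before reusing it. `pad_missing_bands` is a hypothetical helper, and nested lists stand in for the arrays or tensors used in practice.

```python
from typing import Dict, List

# Assumed 13-band Sentinel-2 L1C order expected by the pre-trained backbone.
S2_L1C_BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]

Band = List[List[float]]  # one H x W band


def pad_missing_bands(bands: Dict[str, Band], target_order: List[str] = S2_L1C_BANDS) -> List[Band]:
    """Return bands in target_order, inserting an all-zero H x W band for any missing name."""
    # Take the spatial size from any available band.
    first = next(iter(bands.values()))
    height, width = len(first), len(first[0])
    stacked = []
    for name in target_order:
        if name in bands:
            stacked.append(bands[name])
        else:
            # Missing band (e.g. B10 for BigEarthNet): fill with zeros.
            stacked.append([[0.0] * width for _ in range(height)])
    return stacked


if __name__ == "__main__":
    # Toy 2 x 2 patch containing every band except B10.
    be_bands = {name: [[1.0, 1.0], [1.0, 1.0]] for name in S2_L1C_BANDS if name != "B10"}
    padded = pad_missing_bands(be_bands)
    print(len(padded), padded[S2_L1C_BANDS.index("B10")])
```

Keying bands by name rather than by position avoids silently shifting channels when two datasets order their bands differently. Whether the zero band is inserted before or after per-band normalization changes the value the backbone actually sees; the sketch leaves that choice open.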

Ehzoahis commented 11 months ago

Thank you so much! That is very helpful!