keyu-tian / SparK

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; PyTorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
https://arxiv.org/abs/2301.03580
MIT License

About sparse convolution #73

Closed Itsanewday closed 6 months ago

Itsanewday commented 6 months ago

Many thanks for your easy-to-read paper and detailed code. I have some questions. First, Section 3.1 of the paper says: "To overcome the problems, we propose to sparsely gather all unmasked patches into a sparse image, and then use sparse convolutions to encode it." But in the code I cannot find any sparse convolution. Actually, the SparseConv2d defined in encoder.py is realized with a normal convolution plus an element-wise multiply. Maybe the actual training time is slightly longer than with a normal convolution, am I right? Second, since the dropped patches vary during training, the model has actually seen all the data. Does this mean there is information leakage? Third, during finetuning, how are the learned mask tokens used, or are they simply removed?

keyu-tian commented 6 months ago
  1. We use masked operators or index-select operators to simulate truly sparse convolution, sparse batch normalization, etc., since we found them faster in practice.
  2. This is an interesting perspective, and the problem could exist in all masked autoencoding algorithms. I think data augmentation (we use random flipping and cropping) can mitigate it.
  3. They are removed; we only use the pretrained encoder.
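To illustrate point 1: simulating sparse convolution with a dense operator plus masking (the idea behind SparseConv2d in encoder.py) can be sketched in plain Python for a 1-D case. This is a minimal illustrative sketch, not the repo's implementation; the function name and the 1-D setting are assumptions:

```python
def masked_conv1d(x, kernel, mask):
    """Dense 1-D convolution ("same" padding) whose input and output are
    both multiplied by a binary mask, mimicking sparse convolution:
    masked positions stay zero and neither contribute features to, nor
    receive features from, their neighbors."""
    k = len(kernel)
    pad = k // 2
    xm = [xi * mi for xi, mi in zip(x, mask)]        # zero out masked inputs
    xp = [0.0] * pad + xm + [0.0] * pad              # "same" zero padding
    out = [sum(kernel[j] * xp[i + j] for j in range(k)) for i in range(len(x))]
    return [oi * mi for oi, mi in zip(out, mask)]    # re-mask the output

# Example: position 2 is masked, so it stays zero after the convolution.
y = masked_conv1d([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0], [1, 1, 0, 1])
```

Because the operator itself is a normal dense convolution, it can be slightly slower than a truly sparse kernel in theory, but as noted above the masked version was faster in practice on GPUs.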
Itsanewday commented 6 months ago

Thanks for your fast and kind reply! I have another question about finetuning for semantic segmentation: should I freeze the encoder?

keyu-tian commented 6 months ago

You can refer to what we did in object detection and instance segmentation: https://github.com/keyu-tian/SparK/blob/00883b885dce68413ea048c6703a82d8a67a83b8/downstream_d2/lr_decay.py#L8.

In this function we split a ResNet encoder into N chunks and scale each parameter's learning rate based on which chunk it belongs to. The i-th chunk has a scaling ratio of dec ** (N - i), so deeper chunks get larger learning rates than shallower ones, and the last chunk keeps 100% of the base learning rate. So we actually don't freeze any parameters.
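The per-chunk scaling described above can be sketched as follows (a minimal sketch; the function name and the default dec value are illustrative assumptions, not the repo's actual code):

```python
def chunk_lr_scale(chunk_idx, num_chunks, dec=0.7):
    """Layer-wise LR decay: chunk_idx runs from 1 (shallowest) to
    num_chunks (deepest). Deeper chunks get larger scales; the last
    chunk gets exactly 1.0 (i.e. the full base learning rate).
    NOTE: the name and dec=0.7 are illustrative assumptions."""
    return dec ** (num_chunks - chunk_idx)

# Example: 4 chunks with dec=0.5 -> scales 0.125, 0.25, 0.5, 1.0
scales = [chunk_lr_scale(i, 4, dec=0.5) for i in range(1, 5)]
```

In PyTorch, each scale would typically be applied by putting the chunk's parameters in their own optimizer parameter group with `lr = base_lr * scale`.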

Itsanewday commented 6 months ago

Thank you very much!