keyu-tian / SparK

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; Pytorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
https://arxiv.org/abs/2301.03580
MIT License

SparseBatchNorm2d cannot mask correctly? #78

Closed: FengheTan9 closed this issue 4 months ago

FengheTan9 commented 6 months ago

Hello, I visualized the outputs of SparseConv2d and SparseBatchNorm2d and found that SparseBatchNorm2d does not mask correctly. Could you suggest a solution or offer some support?

keyu-tian commented 6 months ago

(sparse_bn(sparse_conv(raw_inp)) > 0) * 255 should be changed to something like (sparse_bn(sparse_conv(raw_inp)).abs() > 1e-5) * 255, because non-masked features can be negative after BN. Also, some non-masked values that were non-zero can become 0 after BN (normalized to 0, not masked to 0).
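A minimal sketch of why the > 0 check fails, assuming a masked BN that normalizes each channel over the non-masked positions only and writes exact zeros at masked positions. The helper masked_bn is hypothetical and written for illustration. It is not SparK's SparseBatchNorm2d, and it omits the affine weight/bias and the running statistics.

```python
import torch

torch.manual_seed(0)

def masked_bn(x: torch.Tensor, active: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical masked BatchNorm (no affine params, batch statistics only).

    x:      (B, C, H, W) features
    active: (B, 1, H, W) bool mask, True = non-masked position
    Per-channel mean/var use non-masked positions only; masked outputs are 0.
    """
    sel = active.squeeze(1)                  # (B, H, W) bool
    bhwc = x.permute(0, 2, 3, 1)             # (B, H, W, C)
    nc = bhwc[sel]                           # (N_active, C): gather non-masked features
    mean = nc.mean(dim=0)
    var = ((nc - mean) ** 2).mean(dim=0)     # biased variance, as BN uses for normalization
    out = torch.zeros_like(bhwc)
    out[sel] = (nc - mean) / torch.sqrt(var + eps)
    return out.permute(0, 3, 1, 2)           # back to (B, C, H, W)

B, C, H, W = 2, 4, 8, 8
x = torch.randn(B, C, H, W)
active = torch.rand(B, 1, H, W) > 0.6        # keep roughly 40% of positions
y = masked_bn(x, active)

truth = active.expand_as(y)                  # the mask we want to visualize
wrong = y > 0                                # drops non-masked values that BN made negative
right = y.abs() > 1e-5                       # recovers the mask, barring values normalized to ~0

print("'> 0' matches mask:", torch.equal(wrong, truth))
print("'.abs() > 1e-5' matches mask:", torch.equal(right, truth))
```

Roughly half of the normalized non-masked values are negative, so the > 0 test removes them from the visualized mask while the magnitude test keeps them. A non-masked value that happens to normalize to (near) zero would still be misclassified by either test, which is the second caveat in the comment.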

FengheTan9 commented 6 months ago

Thank you for your reply. I have another question: when implementing InstanceNorm2d, is this the correct definition?

keyu-tian commented 6 months ago

I'm not very familiar with how nn.InstanceNorm1d/2d work, but I'm sure that the relationship between InstanceNorm1d and InstanceNorm2d is different from the relationship between BatchNorm1d and BatchNorm2d. So I don't think directly replacing BN with IN (as above) is correct.
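The BN side of that relationship can be checked directly. A minimal sketch, assuming training-mode batch statistics and the default affine initialization: nn.BatchNorm1d applied to features flattened to (B*H*W, C) reproduces nn.BatchNorm2d on (B, C, H, W), because both compute one mean/variance per channel over all samples and positions. nn.InstanceNorm2d computes one mean/variance per (sample, channel) pair instead, so the same flatten-then-normalize approach does not reproduce it. The helpers flatten_bhw and unflatten_bhw are illustrative names, not part of SparK.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
x = torch.randn(B, C, H, W)

def flatten_bhw(t: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B*H*W, C): one row per spatial position, sample index lost
    return t.permute(0, 2, 3, 1).reshape(-1, t.shape[1])

def unflatten_bhw(t: torch.Tensor, b: int, h: int, w: int) -> torch.Tensor:
    # (B*H*W, C) -> (B, C, H, W)
    return t.reshape(b, h, w, -1).permute(0, 3, 1, 2)

bn2d = nn.BatchNorm2d(C).train()
bn1d = nn.BatchNorm1d(C).train()
in2d = nn.InstanceNorm2d(C)            # affine=False, no running stats by default

out_bn2d = bn2d(x)
out_in2d = in2d(x)
bn_via_1d = unflatten_bhw(bn1d(flatten_bhw(x)), B, H, W)

# Same per-channel statistics, so BN1d on the flattened features matches BN2d
print(torch.allclose(out_bn2d, bn_via_1d, atol=1e-5))
# IN2d normalizes each (sample, channel) separately, so it does not match
print(torch.allclose(out_in2d, bn_via_1d, atol=1e-5))
```

Once every position is stacked into a single row dimension, the per-sample grouping that instance statistics need is gone, which is why an IN module cannot simply take BN1d's place on the gathered features.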

I suggest you first implement InstanceNorm without calling nn.InstanceNorm and make sure its output is identical to that of nn.InstanceNorm. Then compute the statistics (mean, std) over the non-masked positions only (perhaps by reusing our _get_active_ex_or_ii, or by implementing a new, similar function).
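The two suggested steps might look like the sketch below. It assumes no affine parameters and no running statistics, and it takes the mask as an explicit (B, 1, H, W) bool argument; in SparK the mask would presumably come from _get_active_ex_or_ii or a similar helper. The functions instance_norm_2d and masked_instance_norm_2d are hypothetical names for illustration, not an existing API.

```python
import torch
import torch.nn as nn

def instance_norm_2d(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Step 1: InstanceNorm without nn.InstanceNorm (no affine, no running stats).

    One mean/variance per (sample, channel) over the H*W positions, using the
    biased variance, as nn.InstanceNorm2d does.
    """
    mean = x.mean(dim=(2, 3), keepdim=True)                  # (B, C, 1, 1)
    var = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True)   # (B, C, 1, 1)
    return (x - mean) / torch.sqrt(var + eps)

def masked_instance_norm_2d(x: torch.Tensor, active: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Step 2 (hypothetical): the same statistics, over non-masked positions only.

    active: (B, 1, H, W) bool mask, True = non-masked position.
    Masked positions are excluded from mean/var and set to 0 in the output.
    """
    m = active.to(x.dtype)                                   # 0/1 weights, broadcast over C
    count = m.sum(dim=(2, 3), keepdim=True).clamp_min(1.0)   # non-masked positions per sample
    mean = (x * m).sum(dim=(2, 3), keepdim=True) / count
    var = (((x - mean) ** 2) * m).sum(dim=(2, 3), keepdim=True) / count
    return (x - mean) / torch.sqrt(var + eps) * m

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
ref = nn.InstanceNorm2d(3)(x)

# Step 1 check: the hand-written version should match nn.InstanceNorm2d
print(torch.allclose(instance_norm_2d(x), ref, atol=1e-5))

# With every position active, the masked version reduces to plain InstanceNorm
all_active = torch.ones(2, 1, 8, 8, dtype=torch.bool)
print(torch.allclose(masked_instance_norm_2d(x, all_active), ref, atol=1e-5))

# With a real mask, statistics come from non-masked positions only
active = torch.rand(2, 1, 8, 8) > 0.6
y = masked_instance_norm_2d(x, active)
```

This sketch keeps the (B, C, H, W) layout and uses the mask as 0/1 weights rather than gathering non-masked features into one (N_active, C) tensor, because instance statistics must stay separated per sample.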