keyu-tian / SparK

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; PyTorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
https://arxiv.org/abs/2301.03580
MIT License

How to deal with the conv in the SE module? #54

Closed: RainFrost1 closed this issue 11 months ago

WailordHe commented 11 months ago

same question

yxchng commented 11 months ago

any updates?

RainFrost1 commented 11 months ago

Looking forward to your reply. Any suggestions? @keyu-tian

keyu-tian commented 11 months ago

@RainFrost1 @will1973 @yxchng The SE module contains two channel-wise linear layers and one global average pooling. The linear layers need no changes. For the global average pooling on a masked feature map, the mean should be computed over the unmasked positions only. To implement this, you can use our function _get_active_ex_or_ii in https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L30 like:

import torch.nn as nn
from encoder import _get_active_ex_or_ii  # defined in pretrain/encoder.py of this repo

class SparseGlobalAveragePooling(nn.Module):
    def forward(self, x):  # x shape: BCHW
        B, C, H, W = x.shape
        # binary map of currently unmasked positions, shape: B1HW
        unmasked_positions = _get_active_ex_or_ii(H=H, W=W, returning_active_ex=True)
        # average over unmasked positions only; masked positions contribute zero
        mean = (x * unmasked_positions).sum(dim=(2, 3), keepdim=True) / unmasked_positions.sum(dim=(2, 3), keepdim=True)
        return mean  # shape: BC11
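
For context, here is a minimal sketch of how this pooling could slot into an SE block. The SparseSE name, the reduction ratio, and the 1x1-conv implementation of the two channel-wise linear layers are illustrative assumptions, not code from this repo; only the pooling needs the sparse-aware replacement:

import torch.nn as nn

class SparseSE(nn.Module):
    """Hypothetical SE block wired to the masked-aware pooling above."""
    def __init__(self, channels, reduction=16):  # reduction ratio is an assumption
        super().__init__()
        self.pool = SparseGlobalAveragePooling()
        # the two channel-wise linear layers, written as 1x1 convs; they act
        # per position, so they require no masking changes
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):          # x shape: BCHW
        w = self.fc(self.pool(x))  # channel weights, shape: BC11
        return x * w               # reweight channels, broadcast over HW

Since the channel weights are computed from the unmasked-only mean, masked positions no longer dilute the SE statistics; everything else in the block behaves exactly as in a dense SE module.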