Kevinz-code / CSRA

Official code of ICCV2021 paper "Residual Attention: A Simple but Effective Method for Multi-Label Recognition"

MobileNet implementation of CSRA #20

Open ghylander opened 2 years ago

ghylander commented 2 years ago

I'm trying to implement CSRA using MobileNet as the backbone, but I'm running into some trouble. This is somewhat related to #5. First of all, it was not clear to me from the paper whether CSRA is meant to be applied before, after, or instead of the classifier.

Now, I have a question: which version of MobileNet was CSRA implemented on? The paper states it's MobileNetV2, but in my case I'm trying to use MobileNetV3-Large.

In my use case, I would like to use the MobileNetV3 classification head, except with a different number of target classes. Where is CSRA supposed to be placed?

This is the structure of the MobileNetV3 classifier: [image: MobileNetV3-Large classification head]

Is CSRA supposed to replace the Avg Pool on the (7,7,960) tensor? To replace the 1x1 Conv after the (1,1,1280) tensor? To be placed after the last 1x1 Conv?

I think most of the confusion comes from Fig 1 and Fig 2 in the CSRA paper.

Kevinz-code commented 2 years ago

Hi @ghylander, thanks for your question and implementation.

Here are the three key steps in the CSRA module:

  1. Generate the attention score s^i_j and the class-specific feature a^i (Eq. 2 and Eq. 3).
  2. Combine each a^i with the average-pooled feature g to get f^i.
  3. Send f^i on to get the i-th class logit (Eq. 6).

Input of CSRA: the feature map before avgpool (B x dimension x H x W)
Output of CSRA: the logits (B x C)

Steps 1-3 can be expressed by Eq. 8 and are implemented in the CSRA class in pipeline/csra.py. You can refer to #5 for our reply.
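For reference, here is a minimal PyTorch sketch of these three steps, written as a simplified reading of the CSRA class in pipeline/csra.py (shapes follow the B x d x H x W convention above; treat it as a sketch, not a verbatim copy of the repo code):

import torch
import torch.nn as nn

class CSRA(nn.Module):
    # Simplified single-head CSRA: input is the (B, d, H, W) feature map, output is the (B, C) logits
    def __init__(self, input_dim, num_classes, T, lam):
        super().__init__()
        self.T = T      # temperature of the attention softmax
        self.lam = lam  # lambda that weights the attention logit
        self.head = nn.Conv2d(input_dim, num_classes, kernel_size=1, bias=False)  # per-location classifier
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # Step 1: per-location class scores s^i_j, normalized by the classifier weight norms
        score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)                  # (B, C, H*W)
        # Step 2: averaging the scores over all locations = the avgpool (global) branch
        base_logit = torch.mean(score, dim=2)     # (B, C)
        # Step 3 (Eq. 8): attention-weighted sum of the scores = the residual attention branch
        score_soft = self.softmax(score * self.T)
        att_logit = torch.sum(score * score_soft, dim=2)
        return base_logit + self.lam * att_logit  # (B, C) logits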

About MobileNetV3, you can apply the 3 CSRA steps in place of the Avg Pool on the (7,7,960) tensor; the final logits will then be the output. There might be a small accuracy drop, since the H-Swish structure is discarded in this case.
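For example, a rough sketch of that placement with torchvision's mobilenet_v3_large (the class count and the T/lambda values here are just placeholders, and CSRA is the sketch above):

import torch
from torchvision.models import mobilenet_v3_large

backbone = mobilenet_v3_large().features        # keeps only the conv trunk, drops avgpool + classifier
x = torch.randn(1, 3, 224, 224)
feat = backbone(x)                              # (1, 960, 7, 7) feature map

csra = CSRA(input_dim=960, num_classes=20, T=1, lam=0.1)  # placeholder hyper-parameters
logits = csra(feat)                             # (1, 20) class logits, ready for a multi-label loss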

For more details, refer to our paper.

Best,

ghylander commented 2 years ago

Thanks for the reply, I think it made things clearer. If I'm not mistaken, the diagram below shows the full "tensor flow" (heh, no pun intended) of the CSRA module, both as implemented in the code and as used in the paper:

[image: diagram of the CSRA module tensor flow] (made a mistake and fixed the diagram)

I do have a question regarding the code. I work mainly with TensorFlow and I'm not fully familiar with PyTorch workflows/structures. What does this call do?

score = score.flatten(2)

I can assume it flattens the dimensions of a tensor, but why is this done? The same applies to the transpose applied to the weight norm:

torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)

Why is the result of the norm transposed before it is used to divide the output of the fully connected layer?
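For context, here is a quick shape check of what those two calls produce (B=2, C=4, d=960 and a 7x7 grid are just example values):

import torch
import torch.nn as nn

x = torch.randn(2, 960, 7, 7)                    # (B, d, H, W) feature map
head = nn.Conv2d(960, 4, kernel_size=1, bias=False)

score = head(x)                                  # (2, 4, 7, 7): one score map per class
print(score.flatten(2).shape)                    # (2, 4, 49): H and W merged into a single axis

# head.weight is (4, 960, 1, 1); its per-class L2 norm keeps dims as (4, 1, 1, 1)
w_norm = torch.norm(head.weight, dim=1, keepdim=True)
print(w_norm.transpose(0, 1).shape)              # (1, 4, 1, 1): lines up with the class axis of the
                                                 # (B, C, H, W) scores, so the division broadcasts per class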

ghylander commented 2 years ago

As an update to this, I managed to implement almost all of CSRA in my MobileNetV3 model. I had to dive deep into both the PyTorch and TensorFlow docs to fully translate one into the other. The only bit I'm currently missing is the multi-head attention, which also connects to a doubt I have about the pipeline/csra.py file.

As far as I understood, using multi-head attention creates a number of parallel heads, and CSRA is then applied within each head with its own temperature T. This results in one vector of shape (Batch, Classes) per head. All of these vectors are then added together element-wise, and the resulting tensor is sent to a sigmoid activation function.

And in the case where num_heads = 1, the resulting tensor of the single head is also sent to the sigmoid activation.

Is that correct?
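For concreteness, here is a minimal sketch of that reading, reusing the CSRA class sketched in the reply above (the per-head temperature list is a placeholder; the element-wise sum over heads is the part I'd like to confirm):

import torch
import torch.nn as nn

class MultiHeadCSRA(nn.Module):
    # My reading: one full CSRA module per head, each with its own temperature,
    # and the per-head (B, C) logits are summed element-wise.
    def __init__(self, input_dim, num_classes, temperatures, lam):
        super().__init__()
        self.heads = nn.ModuleList(
            [CSRA(input_dim, num_classes, T, lam) for T in temperatures]
        )

    def forward(self, x):
        logits = sum(head(x) for head in self.heads)  # element-wise sum over heads -> (B, C)
        return logits  # the sigmoid would then be applied on top of this sum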

ghylander commented 2 years ago

Update on this: I had to put it on hold and am returning to it now. I managed to implement the drop-in CSRA module in TensorFlow (v2.9) for the MobileNetV3-Large backbone, and I half-managed to implement the trainable block.

Can you clarify a few things for me? Some of my trouble comes from the PyTorch -> TensorFlow translation:

1.- In the CSRA class declared in pipeline/csra.py (line 6), in the forward() method (line 18): what you are performing there is a weight normalisation of the feature vector, isn't it?

2.- Then, on line 19, you flatten the height and width of the resulting tensor. I'm not familiar with PyTorch's flatten() method, but I assume it works just like NumPy's, by 'appending' the nested dimensions one after the other, correct?

3.- On line 20, you compute the mean over the HxW axis. Looking at Figure 1 in the paper, this seems to be equivalent to an average pooling operation. Is this correct? (A quick numerical check of this is sketched after this list.)

4.- Lastly, your current implementation works with logits. What impact would it have on CSRA performance to apply an activation function (softmax or sigmoid) to the output logits (needless to say, the loss function used would account for this)? i.e.:

score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
score = score.flatten(2)
base_logit = torch.mean(score, dim=2)
score_soft = self.softmax(score * self.T)
att_logit = torch.sum(score * score_soft, dim=2)

output = torch.nn.Sigmoid()(base_logit + self.lam * att_logit)
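Regarding question 3, here is the quick numerical check mentioned above: because the 1x1-conv classifier is linear (and bias-free), averaging the per-location scores gives the same result as classifying the average-pooled feature (the sizes below are placeholders):

import torch
import torch.nn as nn

x = torch.randn(2, 960, 7, 7)                     # (B, d, H, W) feature map
head = nn.Conv2d(960, 4, kernel_size=1, bias=False)

mean_of_scores = head(x).flatten(2).mean(dim=2)                    # classify, then average
score_of_mean = head(x.mean(dim=(2, 3), keepdim=True)).flatten(1)  # average-pool, then classify
print(torch.allclose(mean_of_scores, score_of_mean, atol=1e-5))    # True

And regarding question 4: if the loss is adjusted accordingly, adding a sigmoid should only change the numerics, not the model, since sigmoid + BCELoss is mathematically equivalent to raw logits + BCEWithLogitsLoss (the latter is just more numerically stable). Continuing the snippet above:

logit = torch.randn(2, 4)                        # stand-in for base_logit + self.lam * att_logit
target = torch.randint(0, 2, (2, 4)).float()     # multi-hot labels

loss_with_logits = nn.BCEWithLogitsLoss()(logit, target)        # sigmoid fused into the loss
loss_with_sigmoid = nn.BCELoss()(torch.sigmoid(logit), target)  # explicit sigmoid + plain BCE
print(torch.allclose(loss_with_logits, loss_with_sigmoid))      # True up to float precision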

Here's my current drop-in CSRA implementation in TF v2.9:

# Defining model input tensor shape, (None) means dynamic shape
inputs = tf.keras.Input(shape=(None, None, 3))

# base_model is the MobileNetV3-Large backbone (defined earlier); features is the resulting tensor with shape (batch, H, W, d) = (32, 7, 7, 960)
features = base_model(inputs, training=False)

# Applying a Fully Connected layer to the backbone output
attentions = tf.keras.layers.Conv2D(1280, kernel_size=1, padding='same', use_bias=False)(features)
# Applying Batch Normalization to the FC layer output
attentions = tf.keras.layers.BatchNormalization()(attentions)
# Applying Average Pooling to the normalized FC output
avg_attentions = tf.keras.layers.GlobalAveragePooling2D()(attentions)
# Applying Max Pooling to the normalized FC output
max_attentions = tf.keras.layers.GlobalMaxPooling2D()(attentions)
# Computing CSRA logit output with lambda = 0.2
csra_output = (avg_attentions + max_attentions*0.2)
# Applying a dropout layer
csra_output_dropout = tf.keras.layers.Dropout(0.2)(csra_output)
# Applying a 2nd FC layer with sigmoid activation
outputs_csra = tf.keras.layers.Dense(2, activation='sigmoid', kernel_initializer='random_uniform', bias_initializer='zeros')(csra_output_dropout)

model_csra = tf.keras.Model(inputs, outputs_csra)

Does this look correct to you?