cly124567 / SeA

Code for the BMVC 2022 paper "SeA: Selective Attention for Fine-grained Visual Categorization".

Can you help me understand your code? #3

Open wjtan99 opened 11 months ago

wjtan99 commented 11 months ago

Thanks for sharing your work. Can you help me understand your code?

class MultiHeadSelfAttention(nn.Module):
    dim_in: int  # input dimension
    dim_k: int  # key and query dimension
    dim_v: int  # value dimension
    num_heads: int  # number of heads, for each head, dim_* = dim_* // num_heads
    num_select: int

    def __init__(self, dim_in, dim_k, dim_v, select_rate, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, "dim_k and dim_v must be multiple of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.select_rate = select_rate
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # batch, nh, n, n

        temp = dist[:,:,0,:].reshape(batch, nh, 1, n)
        index = torch.argsort(-temp, dim=-1)
        index=index[:,:,:,int(self.select_rate*n-1)].reshape(batch,nh,1,1)
        index=index.repeat(1,1,1,n)
        max = torch.take_along_dim(temp,index,dim=3)
        zero=torch.zeros(1).cuda()
        rel = torch.where(temp >= max, temp, zero)
        rel=torch.softmax(rel,dim=-1)

Why do you use only index 0 along the third dimension, in temp = dist[:,:,0,:].reshape(batch, nh, 1, n)? dist is the correlation matrix of q and k. Does this attention use only its 0-th row?

cly124567 commented 11 months ago

Yes. The 0-th row represents the correlation between the classification token and the other regions. Multiplying it by V gives a tensor of shape [B, nh, 1, dv], which represents only the classification token.
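
The reply above can be illustrated with a minimal sketch for a single sample and a single head. It is written in plain Python rather than PyTorch so that it runs without dependencies. The function name select_cls_attention and the toy numbers are assumptions made for illustration and are not part of the repository. The final multiplication by V is inferred from the reply, since the quoted code stops at the softmax.

```python
import math


def select_cls_attention(cls_scores, values, select_rate):
    """Single-head, single-sample sketch of the selection step.

    cls_scores: n scaled q.k scores, i.e. row 0 of dist (CLS token vs. all tokens)
    values:     n value vectors, each of length dv
    """
    n = len(cls_scores)
    # Position of the threshold in the descending sort, as in the quoted code.
    kth = int(select_rate * n - 1)
    threshold = sorted(cls_scores, reverse=True)[kth]
    # Keep scores at or above the threshold and set the rest to 0 (not -inf),
    # mirroring torch.where(temp >= max, temp, zero). Masked positions
    # therefore still receive a nonzero softmax weight.
    kept = [s if s >= threshold else 0.0 for s in cls_scores]
    # Numerically stable softmax over the n positions.
    m = max(kept)
    exps = [math.exp(s - m) for s in kept]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors: a single output vector of length dv,
    # which is why the result describes only the classification token.
    dv = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dv)]
    return weights, out


# Hypothetical inputs: n = 4 tokens, dv = 2.
scores = [2.0, 0.5, 1.5, -1.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
weights, out = select_cls_attention(scores, values, select_rate=0.5)
print(weights, out)
```

With select_rate=0.5 and n=4, the two largest scores are kept and the other two are replaced by 0 before the softmax, so the output is one vector of length dv per head.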

wjtan99 commented 11 months ago

Thank you for your prompt reply. I have set up the code and started training. Once it is running, I should be able to figure out most of the details, though I may have more questions.

wjtan99 commented 11 months ago

I have trained 102 epochs, but the best test ACC is around 92%.
Iteration 100, test_acc = 91.93994, test_acc_combined1 = 92.02623, test_acc_combined2 = 91.97446, test_loss = 0.401251
Iteration 101, test_acc = 91.92268, test_acc_combined1 = 91.85364, test_acc_combined2 = 91.90542, test_loss = 0.385748
Iteration 102, test_acc = 91.93994, test_acc_combined1 = 92.00897, test_acc_combined2 = 91.97446, test_loss = 0.380427
Iteration 103, test_acc = 92.09527, test_acc_combined1 = 92.00897, test_acc_combined2 = 92.18157, test_loss = 0.389744
Iteration 104, test_acc = 92.07801, test_acc_combined1 = 92.00897, test_acc_combined2 = 92.06075, test_loss = 0.387197
Iteration 105, test_acc = 92.06075, test_acc_combined1 = 92.00897, test_acc_combined2 = 92.21609, test_loss = 0.389983
Iteration 106, test_acc = 92.07801, test_acc_combined1 = 91.99172, test_acc_combined2 = 92.18157, test_loss = 0.388145
Iteration 107, test_acc = 91.76735, test_acc_combined1 = 91.88816, test_acc_combined2 = 91.99172, test_loss = 0.389507
Iteration 108, test_acc = 92.06075, test_acc_combined1 = 92.16431, test_acc_combined2 = 92.09527, test_loss = 0.392833
Iteration 109, test_acc = 92.21609, test_acc_combined1 = 92.18157, test_acc_combined2 = 92.23334, test_loss = 0.392652
Iteration 110, test_acc = 92.14705, test_acc_combined1 = 92.11253, test_acc_combined2 = 92.28512, test_loss = 0.393456
Iteration 111, test_acc = 92.21609, test_acc_combined1 = 92.14705, test_acc_combined2 = 92.28512, test_loss = 0.384375
Iteration 112, test_acc = 92.02623, test_acc_combined1 = 92.11253, test_acc_combined2 = 92.07801, test_loss = 0.378721
Iteration 113, test_acc = 92.06075, test_acc_combined1 = 92.11253, test_acc_combined2 = 92.06075, test_loss = 0.394048
Iteration 114, test_acc = 92.25060, test_acc_combined1 = 92.23334, test_acc_combined2 = 92.26786, test_loss = 0.396718
Iteration 115, test_acc = 92.16431, test_acc_combined1 = 92.11253, test_acc_combined2 = 92.12979, test_loss = 0.387735

Can you share a checkpoint that achieves 93% ACC on CUB-200-2011?
Thanks.