seekerhuang / HarMA

[ICLRW 2024] Efficient Remote Sensing with Harmonized Transfer Learning and Modality Alignment
https://arxiv.org/abs/2404.18253
Apache License 2.0
33 stars, 1 fork

The calculation of the parameter #3

Closed: Forrest-ht closed this issue 2 months ago

Forrest-ht commented 2 months ago

The parameter count given in the paper is incorrect. The paper reports 0.5M parameters for the adapter, but the following program gives 5.66M. Could the author check whether this is a typo? I look forward to your reply.

    import torch
    from torch import nn


    class BiShareAdapter(nn.Module):
        def __init__(self, hidden_dim, num_heads):
            super(BiShareAdapter, self).__init__()
            self.hidden_dim = hidden_dim
            self.num_heads = num_heads
            # Bottleneck projection: hidden_dim -> hidden_dim // 2 -> hidden_dim
            self.l1 = nn.Linear(hidden_dim, hidden_dim // 2)
            self.l2 = nn.Linear(hidden_dim // 2, hidden_dim)

            # Multi-head attention on the bottleneck features
            self.multihead_attention1 = nn.MultiheadAttention(hidden_dim // 2, num_heads)
            self.gate1 = nn.Parameter(torch.tensor(0.6), requires_grad=True)


    class MMadapter(nn.Module):
        def __init__(self, share_adapter, hidden_size, layer_id=0):
            super(MMadapter, self).__init__()
            self.img_proj_down = nn.Linear(hidden_size, 128)
            self.img_proj_up = nn.Linear(128, hidden_size)
            # Registered as a submodule, so its parameters are counted with this adapter
            # (None for the image branch in this script)
            self.BiShareAdapterxx = share_adapter
            self.multihead_attention = nn.MultiheadAttention(128, 8)
            self.gate1 = nn.Parameter(torch.tensor(0.6), requires_grad=True)


    # One BiShareAdapter per layer (renamed so the list no longer shadows the class)
    bishare_adapters = nn.ModuleList([BiShareAdapter(128, 8) for _ in range(12)])
    MMadapter_img = nn.ModuleList([MMadapter(None, hidden_size=768, layer_id=layer_id) for layer_id in range(12)])
    MMadapter_text = nn.ModuleList([MMadapter(bishare_adapters[layer_id], hidden_size=512, layer_id=layer_id) for layer_id in range(12)])

    # Note: this divides by 1024 * 1024 (2**20), not by 1e6
    paras = (sum(p.numel() for p in MMadapter_img.parameters())
             + sum(p.numel() for p in MMadapter_text.parameters())) / 1024. / 1024.
    print(f"adapter parameters: {round(paras, 2)} M")
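As a cross-check on the script's arithmetic, the same total can be derived in closed form. This is a minimal sketch that assumes PyTorch's default layer settings (biases enabled, and `kdim == vdim == embed_dim` in `nn.MultiheadAttention`, so each attention block holds 4E² + 4E parameters) and mirrors the module sizes in the script above. The helper names are hypothetical. It also shows the total in both 2**20 units (what the script prints) and decimal millions.

```python
# Closed-form parameter counts mirroring the modules in the script above.
# Assumption: PyTorch defaults, i.e. all Linear/MultiheadAttention layers have biases
# and MultiheadAttention uses a packed Q/K/V projection (kdim == vdim == embed_dim).

def linear_params(d_in: int, d_out: int) -> int:
    """nn.Linear: weight (d_out x d_in) plus bias (d_out)."""
    return d_in * d_out + d_out


def mha_params(embed_dim: int) -> int:
    """Packed Q/K/V projection (3E x E + 3E) plus output projection (E x E + E)."""
    return 4 * embed_dim * embed_dim + 4 * embed_dim


def bishare_params(hidden_dim: int = 128) -> int:
    """Hypothetical helper: BiShareAdapter(hidden_dim, num_heads) from the script."""
    half = hidden_dim // 2
    return (linear_params(hidden_dim, half) + linear_params(half, hidden_dim)
            + mha_params(half) + 1)  # +1 for the scalar gate1


def mmadapter_params(hidden_size: int, bottleneck: int = 128) -> int:
    """Hypothetical helper: MMadapter without its shared adapter."""
    return (linear_params(hidden_size, bottleneck) + linear_params(bottleneck, hidden_size)
            + mha_params(bottleneck) + 1)  # +1 for the scalar gate1


NUM_LAYERS = 12
img = NUM_LAYERS * mmadapter_params(768)                          # image branch, no shared adapter
text = NUM_LAYERS * (mmadapter_params(512) + bishare_params(128))  # text branch incl. BiShareAdapter
total = img + text

print(f"image adapters: {img:,}")
print(f"text adapters (incl. BiShareAdapter): {text:,}")
print(f"total: {total:,}")
print(f"total / 2**20: {total / 2**20:.2f} M")
print(f"total / 1e6:   {total / 1e6:.2f} M")
```

This gives 5,934,372 parameters in total: about 5.66 when divided by 2**20, as in the script, and about 5.93M in decimal millions.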
seekerhuang commented 2 months ago

Please note the † symbol after the parameter counts. For example, the Adapter in the third column has only 0.17M parameters.