WillDreamer / Aurora

[NeurIPS2023] Parameter-efficient Tuning of Large-scale Multimodal Foundation Model
https://arxiv.org/abs/2305.08381

Question about the adapter. #3

Closed Midkey closed 10 months ago

Midkey commented 10 months ago

Thanks for your wonderful work!

When trying to follow your code, I cannot find the code for the adapter imported in Aurora/tree/main/CP/med.py Line 47:

from .adapter import Adapter_Lora

Do we need to install some packages to fix this?

xinlong-yang commented 10 months ago

Thanks for your interest. You don't need to install any packages to fix it; Adapter_Lora is only imported to construct a baseline comparison, so you can simply delete this line. If you want to run experiments with adapters, here is a classic implementation:

import math

import torch
import torch.nn as nn


class Adapter_Lora(nn.Module):
    def __init__(self,
                 d_model=768,
                 bottleneck=64,
                 dropout=0.0,
                 init_option="lora",
                 adapter_scalar="learnable_scalar"):
        super().__init__()
        self.n_embd = d_model
        self.down_size = bottleneck

        # Either learn the scaling factor or use a fixed constant.
        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        # Bottleneck projection: d_model -> bottleneck -> d_model.
        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout  # kept for API compatibility; not applied in this minimal version

    def init_adapter_weights(self):
        # Zero-init the up projection so the adapter starts as an identity mapping.
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
            nn.init.zeros_(self.up_proj.weight)
            nn.init.zeros_(self.down_proj.bias)
            nn.init.zeros_(self.up_proj.bias)

    def forward(self, x, add_residual=True):
        down = self.down_proj(x)
        # LoRA-style: no non-linearity between the two projections.
        # down = self.non_linear_func(down)
        up = self.up_proj(down)
        output = up * self.scale
        return x + output if add_residual else output
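
For reference, a minimal sketch of how such an adapter is typically inserted in parallel with a frozen FFN block. The FFNWithAdapter wrapper, the d_ff size, and the parallel placement are illustrative assumptions, not code from the Aurora repository:

class FFNWithAdapter(nn.Module):
    """Hypothetical wrapper: adapter runs in parallel with a frozen FFN."""
    def __init__(self, d_model=768, d_ff=3072, bottleneck=64):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        # Freeze the original FFN weights; only the adapter is trained.
        for p in self.ffn.parameters():
            p.requires_grad = False
        self.adapter = Adapter_Lora(d_model=d_model, bottleneck=bottleneck)
        self.adapter.init_adapter_weights()

    def forward(self, x):
        # Adapter output is added alongside the frozen FFN path,
        # plus the usual residual connection.
        return x + self.ffn(x) + self.adapter(x, add_residual=False)

x = torch.randn(2, 16, 768)  # (batch, sequence length, hidden size)
block = FFNWithAdapter()
print(block(x).shape)  # torch.Size([2, 16, 768])

Because the up projection is zero-initialized, the adapter contributes nothing at the start of training, so the wrapped block initially behaves exactly like the frozen FFN.
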
Midkey commented 10 months ago

Thanks for your reply! I have fixed the problem.