Allen0307 / AdapterBias

Code for the Findings of NAACL 2022 (Long Paper): AdapterBias: Parameter-efficient Token-dependent Representation Shift for Adapters in NLP Tasks

Hello, may I ask which module your AdapterBias structure is in? #4

Closed xuguangyi1999 closed 1 year ago

xuguangyi1999 commented 1 year ago

Hello, may I ask which module your AdapterBias structure is in?

Allen0307 commented 1 year ago

Hi,

Here is a sample implementation of our adapter:

import torch
import torch.nn as nn


class AdapterBias(nn.Module):
    '''Implementation of AdapterBias (token-dependent representation shift).
    Reference: https://arxiv.org/abs/2205.00305
    '''
    def __init__(self, config):
        super().__init__()
        # Task-specific shift vector v, one entry per hidden dimension.
        self.adapter_vector = nn.Parameter(torch.ones(config.hidden_size))
        # Linear layer producing a scalar weight alpha for each token.
        self.adapter_alpha = nn.Linear(config.intermediate_size, 1)

    def forward(self, hidden_states):
        # Token-dependent shift: alpha(token) * v, broadcast over the hidden dimension.
        return self.adapter_alpha(hidden_states) * self.adapter_vector

# Insert AdapterBias at the second feed-forward layer of each Transformer block.
class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if config.adapter == 'adapterbias':
            self.adapter = AdapterBias(config)

        self.config = config

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Keep the intermediate activations so the adapter sees the same input as self.dense.
        residual_states = hidden_states
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add the token-dependent bias produced by AdapterBias.
        if self.config.adapter == 'adapterbias':
            hidden_states = self.adapter(residual_states) + hidden_states

        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
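
In case it helps, below is a minimal usage sketch (not from the repo): the config object is a stand-in with the standard BERT config fields the classes above read, and the batch/sequence sizes are arbitrary. It runs one forward pass through the modified BertOutput and then freezes everything except the adapter parameters, which is the parameter-efficient fine-tuning setup.

# Minimal usage sketch (illustrative, not from the repo).
from types import SimpleNamespace

# Stand-in config with the fields read by AdapterBias and BertOutput above.
config = SimpleNamespace(
    hidden_size=768,
    intermediate_size=3072,
    layer_norm_eps=1e-12,
    hidden_dropout_prob=0.1,
    adapter='adapterbias',
)

layer = BertOutput(config)

# hidden_states: output of the first (intermediate) feed-forward layer;
# input_tensor: the residual connection into this sub-layer.
hidden_states = torch.randn(2, 16, config.intermediate_size)
input_tensor = torch.randn(2, 16, config.hidden_size)
print(layer(hidden_states, input_tensor).shape)  # torch.Size([2, 16, 768])

# Parameter-efficient fine-tuning: train only the adapter parameters.
for name, param in layer.named_parameters():
    param.requires_grad = 'adapter' in name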
xuguangyi1999 commented 1 year ago

Thank you very much!