veya2ztn / fast_retention

Speed up Parallel Retention by about 2x

Does this implementation also support Multiscale Retention? #5

Closed Shreyas-Dongre closed 1 year ago

Shreyas-Dongre commented 1 year ago

Hey, does the implementation support Multiscale Retention in parallel mode? I did see that the number of heads is an input hyperparameter, but I am not able to tell whether MSR is fully implemented. Is the output returned by 'SelfRetentionV2', i.e. self.group_norm(o), None, cache, the output of MSR (given that the number of heads is > 1)?

veya2ztn commented 1 year ago

Yes. I will refer you to a Hugging Face transformers-compatible implementation of Retention Networks. Notice that I integrate the group_norm into the self-retention, so you may need to modify it a bit, like:

import torch
import torch.nn as nn
from typing import Optional, Tuple

# RetNetConfig, get_activation_fn, split_heads, theta_shift and SelfRetention
# are taken from the RetNet implementation referenced below.

class MultiScaleRetention(nn.Module):

    def __init__(
        self,
        config: RetNetConfig,
        gate_fn="swish",
        use_bias=False,
        tensor_parallel=False,
    ):
        super().__init__()
        self.config = config

        self.embed_dim = config.decoder_embed_dim
        self.value_dim = config.decoder_value_embed_dim
        self.num_heads = config.decoder_retention_heads
        self.head_dim = self.value_dim // self.num_heads
        self.key_dim = self.embed_dim // self.num_heads
        self.scaling = self.key_dim**-0.5

        self.gate_fn = get_activation_fn(activation=str(gate_fn))

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
        self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)

        self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias)
        self.self_retention = SelfRetention(config)
        self.reset_parameters()

        assert not tensor_parallel
        #self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False) if tensor_parallel else None

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rel_pos: Tuple[Tuple[torch.Tensor]],
        retention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        forward_impl: str = 'parallel',
        output_retentions: Optional[bool] = False,
        output_increment: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
        B, T, H = hidden_states.size()

        (sin, cos), decay_mask = rel_pos

        # projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        g = self.g_proj(hidden_states)

        # multi-head
        q, k, v = split_heads((q, k, v), B, T, self.num_heads)

        k = k * self.scaling  # for scaled dot product
        # rotate
        # NOTE: theta_shift has a bug on the mps device.
        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        retention_out, retention_weights, curr_kv, increment = self.self_retention(
            qr, kr, v, decay_mask,
            past_key_value=past_key_value,
            retention_mask=retention_mask,
            forward_impl=forward_impl,
            output_increment=output_increment,
        )

        # concat heads
        # normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
        # ## <--- it is better to move the group_norm into self_retention, so the results
        # ##      from the different forward modes match; otherwise recurrent and parallel
        # ##      agree, but chunkwise is wrong.
        # out gate & proj
        out = self.gate_fn(g) * retention_out.reshape(B, T, self.value_dim)
        out = self.out_proj(out)

        outputs = (out, curr_kv, retention_weights, increment)

        return outputs
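
For reference, a minimal usage sketch. It assumes RetNetConfig accepts decoder_embed_dim, decoder_value_embed_dim and decoder_retention_heads as keyword arguments (matching the fields read in __init__ above), and that rel_pos is the ((sin, cos), decay_mask) tuple produced by the model's relative-position/decay helper; the exact constructor and helper names may differ in your copy of the repo.

import torch

# hypothetical dimensions, chosen only for illustration
config = RetNetConfig(
    decoder_embed_dim=512,
    decoder_value_embed_dim=1024,
    decoder_retention_heads=4,
)
msr = MultiScaleRetention(config)

B, T = 2, 128
hidden_states = torch.randn(B, T, config.decoder_embed_dim)

# rel_pos = ((sin, cos), decay_mask) must be built by the model's
# rotary/decay helper for sequence length T; it is not constructed here.
out, curr_kv, retention_weights, increment = msr(
    hidden_states,
    rel_pos,
    forward_impl='parallel',
)
# out has shape (B, T, config.decoder_embed_dim)
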
veya2ztn commented 1 year ago

Or you can check this repo https://github.com/veya2ztn/RetNet

Shreyas-Dongre commented 1 year ago

Hey, thank you so much! It worked. Is there any way I could contact you? Email or something? Regards, Shreyas

veya2ztn commented 1 year ago

zhangtianning110@gmail.com