Yes, I will refer you to a Hugging Face transformers-compatible implementation of Retention Networks. Notice that I integrate the group_norm into the SelfRetention module, so you may modify it a bit, like:
import torch
import torch.nn as nn
from typing import Optional, Tuple

# RetNetConfig, get_activation_fn, split_heads, theta_shift, and SelfRetention
# come from the surrounding RetNet implementation.

class MultiScaleRetention(nn.Module):
    def __init__(
        self,
        config: RetNetConfig,
        gate_fn="swish",
        use_bias=False,
        tensor_parallel=False,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.decoder_embed_dim
        self.value_dim = config.decoder_value_embed_dim
        self.num_heads = config.decoder_retention_heads
        self.head_dim = self.value_dim // self.num_heads
        self.key_dim = self.embed_dim // self.num_heads
        self.scaling = self.key_dim**-0.5

        self.gate_fn = get_activation_fn(activation=str(gate_fn))

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
        self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
        self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias)

        self.self_retention = SelfRetention(config)
        self.reset_parameters()
        assert not tensor_parallel
        # self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False) if tensor_parallel else None

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rel_pos: Tuple[Tuple[torch.Tensor]],
        retention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        forward_impl: str = 'parallel',
        output_retentions: Optional[bool] = False,
        output_increment: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
        B, T, H = hidden_states.size()
        (sin, cos), decay_mask = rel_pos

        # projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        g = self.g_proj(hidden_states)

        # multi-head: reshape projections into per-head tensors
        q, k, v = split_heads((q, k, v), B, T, self.num_heads)
        k = k * self.scaling  # for scaled dot product

        # rotate
        # NOTE: theta_shift has a bug with the mps device.
        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        retention_out, retention_weights, curr_kv, increment = self.self_retention(
            qr, kr, v, decay_mask,
            past_key_value=past_key_value,
            retention_mask=retention_mask,
            forward_impl=forward_impl,
            output_increment=output_increment,
        )

        # concat heads
        # normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
        # ## <--- it is better to move the group_norm into SelfRetention, so the results
        # ##      obtained from the different forward methods are the same; otherwise only
        # ##      the recurrent and parallel paths match, while chunkwise is wrong.

        # out gate & proj
        out = self.gate_fn(g) * retention_out.reshape(B, T, self.value_dim)
        out = self.out_proj(out)

        outputs = (out, curr_kv, retention_weights, increment)
        return outputs
Or you can check this repo https://github.com/veya2ztn/RetNet
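For readers following the shapes above, here is a minimal sketch of what a split_heads-style helper does; the signature and the extra contiguous() call are assumptions for illustration, and the repo's actual helper may differ:

import torch

# Sketch only: reshape (B, T, dim) projections into (B, num_heads, T, head_dim)
# so retention can be computed independently per head.
def split_heads_sketch(tensors, B, T, num_heads):
    return tuple(
        t.view(B, T, num_heads, -1).transpose(1, 2).contiguous()
        for t in tensors
    )

# q, k use the key dimension per head; v uses the (possibly larger) value dimension
q = torch.randn(2, 8, 64)    # (B, T, embed_dim)
k = torch.randn(2, 8, 64)
v = torch.randn(2, 8, 128)   # (B, T, value_dim)
qh, kh, vh = split_heads_sketch((q, k, v), 2, 8, 4)
print(qh.shape, vh.shape)    # torch.Size([2, 4, 8, 16]) torch.Size([2, 4, 8, 32])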
Hey, thank you so much! It worked. Is there any way I could contact you? Email or something? Regards, Shreyas
Hey, does the implementation support Multi-Scale Retention in parallel mode? I did see that the number of heads is an input hyperparameter, but I am not able to tell whether MSR is completely implemented. The output returned by SelfRetentionV2 is
self.group_norm(o), None, cache
Is that the output of MSR (given that the number of heads is > 1)?
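For context, a minimal sketch of the per-head group normalization that multi-scale retention applies before the output gate; the shapes and the GroupNorm configuration here are illustrative assumptions, not necessarily the repo's exact code:

import torch
import torch.nn as nn

# Illustration only: normalize each retention head independently by using one
# GroupNorm group per head, then merge heads back into (B, T, value_dim).
B, T, num_heads, head_dim = 2, 8, 4, 32
value_dim = num_heads * head_dim

retention_out = torch.randn(B, num_heads, T, head_dim)  # assumed per-head retention output

group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=value_dim, affine=False)

# (B, num_heads, T, head_dim) -> (B*T, value_dim), with each head's dims contiguous per group
x = retention_out.permute(0, 2, 1, 3).reshape(B * T, value_dim)
normed = group_norm(x).reshape(B, T, value_dim)  # ready for the gate and out_proj
print(normed.shape)  # torch.Size([2, 8, 128])

With one group per head, each head is normalized on its own statistics, so the concatenated, normalized tensor is the multi-head (multi-scale) retention output that then goes through the swish gate and out_proj.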