Check this
import torch
import torch.nn as nn
import numpy as np
from self_retention import SelfRetentionV2, RetNetRelPosV2, RMSNorm
from configuration_retnet import RetNetConfig
S = 30
B = 2
H = 8
qk_dim = 32
v_dim = 64
q = torch.randn(B, H, S, qk_dim).cuda()
k = torch.randn(B, H, S, qk_dim).cuda()
v = torch.randn(B, H, S, v_dim).cuda()
config = RetNetConfig(decoder_layers=1,
                      decoder_embed_dim=256,
                      decoder_value_embed_dim=256,
                      decoder_retention_heads=8,
                      decoder_ffn_embed_dim=128)
retnet_rel_pos = RetNetRelPosV2(config).cuda()
model = SelfRetentionV2(config).cuda()  # move the module to GPU so it matches q, k, v
(cos, sin), decay_system = retnet_rel_pos(S, forward_impl='parallel')  # use cos, sin for RoPE position embedding
parallel_output_qk_with_gk, _, parallel_cache = model(q, k, v, decay_system, mode='qk_first', normlize_for_stable=True)
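As a quick sanity check after the call above (a minimal sketch; it only assumes the first return value keeps the [B, H, S, v_dim] layout of the inputs, which may differ by repo version):
print(parallel_output_qk_with_gk.shape)  # expected torch.Size([2, 8, 30, 64]) under that assumption
print(type(parallel_cache))              # presumably the incremental state for the recurrent/chunkwise forms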
Hey,
Thank you for the quick reply. I now understand your code and the intuition behind it. Awesome stuff.
One small doubt: increasing decoder_layers and decoder_retention_heads will make it Multi-Scale Retention, correct?
Thank you so much! :)
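(Editor's note on the multi-scale part, a minimal sketch of the per-head decays as defined in the RetNet paper rather than this repo's internals: each retention head h gets its own decay gamma_h = 1 - 2^(-5 - h), so decoder_retention_heads controls how many decay scales there are.)
import torch
num_heads = 8
# per-head decays from the RetNet paper: gamma = 1 - 2^(-5 - arange(0, num_heads))
gammas = 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float))
print(gammas)  # eight distinct decay rates, one scale per head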
Hey, the parallel form of Retention returns a tuple of two values, but in your README one of the examples shows the parallel retention output as a single tensor. So I am confused about how to get a single tensor as output instead of a tuple. Thank you
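(Editor's note, a minimal sketch of how the single tensor can be pulled out, assuming the first element of the returned tuple is the output tensor and the remaining elements are cache/state; the exact tuple layout depends on the repo version and the README example's module:)
out, *rest = model(q, k, v, decay_system, mode='qk_first', normlize_for_stable=True)
print(out.shape)  # out is the single output tensor; rest holds the cache/state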