veya2ztn / fast_retention

Speed up Parallel Retention by about 2x

How to use Parallel Retention? #4

Closed Shreyas-Dongre closed 1 year ago

Shreyas-Dongre commented 1 year ago

Hey, the parallel form of Retention returns a tuple of two values, but one of the examples in your README shows parallel retention's output as a single tensor. So I am confused as to how to get one tensor as output instead of a tuple. Thank you

veya2ztn commented 1 year ago

Check this


import torch
import torch.nn as nn
import numpy as np
from self_retention import SelfRetentionV2, RetNetRelPosV2, RMSNorm
from configuration_retnet import RetNetConfig

S = 30   # sequence length
B = 2    # batch size
H = 8    # number of retention heads
qk_dim = 32
v_dim  = 64
q = torch.randn(B, H, S, qk_dim).cuda()
k = torch.randn(B, H, S, qk_dim).cuda()
v = torch.randn(B, H, S, v_dim).cuda()

config = RetNetConfig(decoder_layers=1,
                      decoder_embed_dim=256,
                      decoder_value_embed_dim=256,
                      decoder_retention_heads=8,
                      decoder_ffn_embed_dim=128)
retnet_rel_pos = RetNetRelPosV2(config).cuda()
model          = SelfRetentionV2(config)

(cos, sin), decay_system = retnet_rel_pos(S, forward_impl='parallel')  # use cos, sin for RoPE position embedding
parallel_output_qk_with_gk, _, parallel_cache = model(q, k, v, decay_system, mode='qk_first', normlize_for_stable=True)

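For completeness, here is a minimal sketch of pulling a single tensor out of that tuple. It assumes, based only on the variable names above, that the first element is the retention output the README refers to and that parallel_cache is the state reused for recurrent/chunkwise decoding; retention_output is just an illustrative name.

# The call returns a tuple; the first element is a single torch.Tensor
# (assumption from the variable names above: the retention output,
# with parallel_cache being the decoding state).
retention_output = parallel_output_qk_with_gk
print(type(retention_output))   # torch.Tensor, not a tuple
print(retention_output.shape)   # check the exact shape on your installation
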
Shreyas-Dongre commented 1 year ago

Hey, thank you for the quick reply. I now understand your code and the intuition behind it. Awesome stuff. One small doubt: increasing decoder_layers and decoder_retention_heads will make it Multi-Scale Retention, correct?

Thank you so much! :)