Closed pfeatherstone closed 12 months ago
+1!
@pfeatherstone @hugofloresgarcia you can already use null key values by setting `attn_num_mem_kv = {num null k/v}` on either the `Encoder` or `Decoder`
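For readers unfamiliar with the trick: null key/values are learned k/v pairs prepended to the attention keys and values, giving the softmax somewhere to put probability mass when a query should attend to "nothing". This is a minimal plain-PyTorch sketch of the idea, not the library's actual implementation; `attend_with_null_kv` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def attend_with_null_kv(q, k, v, null_k, null_v):
    # null_k / null_v: (num_null, dim) learned parameters, shared across the batch.
    # Prepend them to the keys/values so every query can attend to them.
    b, d = k.shape[0], k.shape[-1]
    k = torch.cat((null_k.unsqueeze(0).expand(b, -1, -1), k), dim = 1)
    v = torch.cat((null_v.unsqueeze(0).expand(b, -1, -1), v), dim = 1)

    # standard scaled dot-product attention over the extended k/v
    attn = F.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim = -1)
    return attn @ v
```

Because the null slots always exist, even a fully masked-out context row still has valid positions to attend to, which avoids NaNs from an all-masked softmax.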
yup i can add it
wow, the continuous wrapper is very popular! had no idea
i think there is a bug. I'll knock up a quick repro:
```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in = 4,
    dim_out = 256 + 3,
    max_seq_len = 0,
    num_memory_tokens = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        use_scalenorm = True,
        attn_onnxable = True,
        shift_tokens = 1
    )
)

x = torch.randn(2, 1024, 4)

# variable lengths per batch element, turned into a boolean padding mask
l = torch.randint(100, x.shape[1], size = (x.shape[0],))
m = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)

out = lm(x, mask = m)
```
I'll file a new bug
@pfeatherstone oh oops, yup, should be fixed in 1.23.4
Can we support either `num_memory_tokens` or null key/values in `ContinuousTransformerWrapper`, please?
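For context, memory tokens differ from null k/v: they are learned embeddings prepended to the input sequence itself, so they participate in attention at every layer, and the padding mask must be extended to always include them. A minimal plain-PyTorch sketch of that bookkeeping (the `prepend_memory_tokens` helper and its signature are illustrative, not the library's API):

```python
import torch

def prepend_memory_tokens(x, mask, mem_tokens):
    # x: (batch, seq, dim), mask: (batch, seq) bool,
    # mem_tokens: (num_mem, dim) learned embeddings shared across the batch
    b = x.shape[0]
    mem = mem_tokens.unsqueeze(0).expand(b, -1, -1)
    x = torch.cat((mem, x), dim = 1)

    # memory tokens are never padding, so the mask is extended with True
    mem_mask = torch.ones(b, mem_tokens.shape[0], dtype = torch.bool)
    mask = torch.cat((mem_mask, mask), dim = 1)
    return x, mask
```

Forgetting to extend the mask alongside the sequence is exactly the kind of shape mismatch the repro above triggers when `num_memory_tokens > 0` is combined with a padding mask.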