Closed: pfeatherstone closed this issue 6 months ago

If you set both `attn_num_mem_kv` and `attn_one_kv_head`, you get an error.

Repro:
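(The original repro snippet isn't preserved here; below is a minimal sketch of the kind of configuration that triggers the error, assuming the usual x-transformers `TransformerWrapper` / `Decoder` entry points and the `attn_`-prefixed kwargs routing to the attention layers.)

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_num_mem_kv = 16,     # learned memory key / values
        attn_one_kv_head = True   # single shared key / value head
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)  # errored before the fix
```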
@pfeatherstone ah nice catch
corrected
The first fix is:
```python
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
    # size the memory key / values by kv_heads (1 when one_kv_head is set),
    # not heads, so they match the projected keys / values
    self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
    self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
```
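Why `kv_heads` matters here (a sketch of the shape logic, not the library's exact code): with a single key/value head, the projected keys have one head, so the memory keys must share that head count for the concatenation along the sequence axis to work. Assuming tensors shaped `(batch, heads, seq, dim_head)`:

```python
import torch

batch, kv_heads, dim_head, num_mem_kv, seq = 2, 1, 64, 16, 128

k = torch.randn(batch, kv_heads, seq, dim_head)      # projected keys, single kv head
mem_k = torch.randn(kv_heads, num_mem_kv, dim_head)  # memory keys sized by kv_heads

# broadcast memory keys over the batch, then prepend along the sequence axis;
# the concat only succeeds if both tensors agree on the head dimension
mem_k = mem_k.unsqueeze(0).expand(batch, -1, -1, -1)
k = torch.cat((mem_k, k), dim = -2)
print(k.shape)  # torch.Size([2, 1, 144, 64])
```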
Not so fast
@pfeatherstone there's another issue?
weird, let me try your example, runs for me
oh i see, it is an issue with flash attention
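For context, one plausible way this class of error shows up in a flash attention path (a sketch, assuming it goes through `torch.nn.functional.scaled_dot_product_attention`, which expects the query and key/value head counts to line up): the single kv head has to be expanded to match the query heads before the fused kernel.

```python
import torch
import torch.nn.functional as F

batch, heads, kv_heads, dim_head = 2, 8, 1, 64
q = torch.randn(batch, heads, 128, dim_head)
k = torch.randn(batch, kv_heads, 144, dim_head)  # 128 tokens + 16 memory kv
v = torch.randn(batch, kv_heads, 144, dim_head)

# expand the single kv head across all query heads before the fused kernel
k, v = map(lambda t: t.expand(-1, heads, -1, -1), (k, v))
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```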
ok, fixed for real, thanks for reporting