JL-er closed this issue 8 months ago.
Hello, which Triton version are you using? Could you also try deleting the Triton cache, which should be located in `~/.triton`?
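Both suggestions can be scripted. The following is a minimal sketch that assumes Triton keeps compiled kernels under `~/.triton/cache` unless the `TRITON_CACHE_DIR` environment variable points elsewhere (verify this against your installed Triton version). The helpers `triton_cache_dir`, `clear_triton_cache` and `triton_version` are hypothetical names, not part of Triton.

```python
import os
import shutil
from pathlib import Path
from typing import Optional


def triton_cache_dir() -> Path:
    """Return the directory assumed to hold Triton's kernel cache.

    Assumption: TRITON_CACHE_DIR overrides the default ~/.triton/cache.
    """
    override = os.environ.get("TRITON_CACHE_DIR")
    if override:
        return Path(override).expanduser()
    return Path.home() / ".triton" / "cache"


def clear_triton_cache(cache_dir: Optional[Path] = None) -> bool:
    """Delete the cache directory; return True if something was removed."""
    path = Path(cache_dir) if cache_dir is not None else triton_cache_dir()
    if not path.exists():
        return False
    shutil.rmtree(path)
    return True


def triton_version() -> Optional[str]:
    """Return the installed Triton version, or None if Triton is not importable."""
    try:
        import triton
    except ImportError:
        return None
    return triton.__version__
```

`triton_version()` answers the version question, and `clear_triton_cache()` removes the cached kernels so they are recompiled on the next run.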
Thanks, the problem was solved after I updated Triton.
```python
import torch
from gla import GatedLinearAttention

d_model = 1024
num_head = 4
use_gk = True   # alpha
use_gv = False  # beta
device = "cuda:0"

gla_layer = GatedLinearAttention(d_model, num_head, use_gk, use_gv).to(device)

bsz, seq_len, d_model = 32, 2048, 1024
x = torch.randn(bsz, seq_len, d_model).to(device)
y = gla_layer(x)

assert y.shape == x.shape
```
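To illustrate what the key-dimension gate enabled by `use_gk=True` (alpha) does, here is a slow, pure-PyTorch sketch of a gated linear attention recurrence for a single head. It assumes the formulation $S_t = \mathrm{diag}(\alpha_t) S_{t-1} + k_t^\top v_t$, $o_t = q_t S_t$, and it is a hypothetical helper (`naive_gated_linear_attention`), not part of the `gla` package. It omits the projections, gate parameterization, normalization and head splitting that `GatedLinearAttention` performs, so its output is not directly comparable to `gla_layer(x)`.

```python
import torch


def naive_gated_linear_attention(q, k, v, alpha):
    """Naive recurrent gated linear attention (single head, no normalization).

    Shapes: q, k, alpha: (batch, seq_len, d_k); v: (batch, seq_len, d_v).
    alpha holds per-step decay factors in [0, 1] applied along the key dim.
    Recurrence: S_t = diag(alpha_t) @ S_{t-1} + k_t^T v_t ; o_t = q_t @ S_t
    """
    batch, seq_len, d_k = q.shape
    d_v = v.shape[-1]
    state = q.new_zeros(batch, d_k, d_v)
    outputs = []
    for t in range(seq_len):
        # Decay each row of the state by the gate for that key dimension.
        state = alpha[:, t, :, None] * state
        # Add the rank-1 outer product k_t^T v_t.
        state = state + k[:, t, :, None] * v[:, t, None, :]
        # Read out the state with the current query.
        outputs.append(torch.einsum("bk,bkv->bv", q[:, t], state))
    return torch.stack(outputs, dim=1)
```

With `alpha` fixed to ones this reduces to plain causal linear attention, which makes it a handy sanity check for shapes and masking.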