puallee opened this issue 1 month ago
```python
import torch
from mixture_of_attention import MixtureOfAutoregressiveAttention

if __name__ == "__main__":
    mixture_of_attn = MixtureOfAutoregressiveAttention(
        dim = 512,
        local_attn_window_size = 64,  # local attention window size
        routed_window_size = None,
        num_routed_queries = 12,
        num_routed_key_values = 12,
        num_experts = 2,
        dim_head = 64,
        heads = 8
    ).cuda()

    x = torch.randn(1, 1023, 512).cuda()

    out = mixture_of_attn(x)  # (1, 1023, 512)
    print(out.shape)
```
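The forward pass fails inside `F.scaled_dot_product_attention` (full traceback below). To help isolate it, here is a standalone SDPA call with the per-head shapes this config implies (batch 1, heads = 8, 1023 tokens, dim_head = 64); the boolean causal mask is only my guess at what attend.py builds internally, so treat this as a sketch, not the repo's actual call:

```python
# Standalone SDPA sanity check (hypothetical shapes inferred from the
# config above; the causal mask is an assumption, not attend.py's mask).
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1023, 64, device = 'cuda')
k = torch.randn(1, 8, 1023, 64, device = 'cuda')
v = torch.randn(1, 8, 1023, 64, device = 'cuda')
mask = torch.ones(1023, 1023, dtype = torch.bool, device = 'cuda').tril()

out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
print(out.shape)  # torch.Size([1, 8, 1023, 64]) if any kernel is available
```

If that snippet runs, the failure is probably about which SDPA backends attend.py enables for this GPU/dtype rather than the shapes themselves.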
```
Traceback (most recent call last):
  File "/home/models/mixture_of_attention/mixture_of_attention.py", line 647, in <module>
    out = mixture_of_attn(x)  # (1, 1023, 512)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/models/mixture_of_attention/mixture_of_attention.py", line 563, in forward
    attn_out = self.attn(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/models/mixture_of_attention/mixture_of_attention.py", line 211, in forward
    out = self.attend(q, k, v, mask = mask)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/models/mixture_of_attention/attend.py", line 110, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/home/models/mixture_of_attention/attend.py", line 87, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
```
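Not a fix from the repo itself, but a possible workaround, sketched under the assumption that every backend attend.py enables rejects these inputs (an attention mask in fp32 commonly rules out the flash kernel on many GPUs): force PyTorch's math fallback around the call.

```python
# Workaround sketch: restrict SDPA to the math fallback.
# (Assumes the failure is backend selection; sdp_kernel is the older
# context manager, deprecated in newer PyTorch in favor of
# torch.nn.attention.sdpa_kernel.)
from torch.backends.cuda import sdp_kernel

with sdp_kernel(enable_flash = False, enable_mem_efficient = False, enable_math = True):
    out = mixture_of_attn(x)
```

On PyTorch 2.3+, the equivalent is `torch.nn.attention.sdpa_kernel(SDPBackend.MATH)`, so adjust to whichever your version provides.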