lucidrains / mixture-of-attention

Some personal experiments around routing tokens to different autoregressive attention, akin to mixture-of-experts
MIT License

Cannot run on CUDA #2

Open puallee opened 1 month ago

puallee commented 1 month ago

The call into `F.scaled_dot_product_attention` fails with `RuntimeError: No available kernel`:

    Traceback (most recent call last):
      File "/home/models/mixture_of_attention/mixture_of_attention.py", line 647, in <module>
        out = mixture_of_attn(x) # (1, 1023, 512)
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/models/mixture_of_attention/mixture_of_attention.py", line 563, in forward
        attn_out = self.attn(
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/models/mixture_of_attention/mixture_of_attention.py", line 211, in forward
        out = self.attend(q, k, v, mask = mask)
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/models/mixture_of_attention/attend.py", line 110, in forward
        return self.flash_attn(q, k, v, mask = mask)
      File "/home/models/mixture_of_attention/attend.py", line 87, in flash_attn
        out = F.scaled_dot_product_attention(
    RuntimeError: No available kernel. Aborting execution.
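
For context, `RuntimeError: No available kernel` from `F.scaled_dot_product_attention` usually means none of the enabled fused CUDA backends (flash / memory-efficient) accepts the combination of dtype, mask and head dimension being passed in. Below is a minimal sketch to reproduce the call in isolation, assuming PyTorch 2.x on CUDA; the shapes are illustrative, based on the `heads = 8` / `dim_head = 64` config in the snippet further down, not taken from `attend.py`:

    import torch
    import torch.nn.functional as F

    # illustrative shapes: batch 1, 8 heads, sequence 1023, dim_head 64 (an assumption)
    q = torch.randn(1, 8, 1023, 64, device = 'cuda')
    k = torch.randn(1, 8, 1023, 64, device = 'cuda')
    v = torch.randn(1, 8, 1023, 64, device = 'cuda')
    mask = torch.ones(1, 1, 1023, 1023, device = 'cuda', dtype = torch.bool)

    # restrict SDPA to the math backend only - if this succeeds, the fused flash /
    # memory-efficient kernels are the ones rejecting the inputs (fp32 tensors and
    # boolean attention masks are the usual reasons)
    # (newer PyTorch versions prefer torch.nn.attention.sdpa_kernel for the same thing)
    with torch.backends.cuda.sdp_kernel(enable_flash = False, enable_mem_efficient = False, enable_math = True):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)

    print(out.shape)  # torch.Size([1, 8, 1023, 64])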

puallee commented 1 month ago

Reproduction snippet (per the traceback it sits at module level in `mixture_of_attention.py`; the two imports are only needed when running it as a standalone script):

    import torch
    from mixture_of_attention import MixtureOfAutoregressiveAttention

    if __name__ == "__main__":
        # Mixture-of-Attention
        mixture_of_attn = MixtureOfAutoregressiveAttention(
            dim = 512,
            local_attn_window_size = 64,       # local attention window size
            routed_window_size = None,
            num_routed_queries = 12,
            num_routed_key_values = 12,
            num_experts = 2,
            dim_head = 64,
            heads = 8
        ).cuda()

        x = torch.randn(1, 1023, 512).cuda()
        out = mixture_of_attn(x)  # (1, 1023, 512)
        print(out.shape)
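
A workaround sketch, assuming PyTorch 2.x on a GPU with half-precision support (not something the repo documents): cast the module and input to fp16 so the fused kernels see a dtype they generally accept. Wrapping the forward call in a math-only `torch.backends.cuda.sdp_kernel(...)` context, as in the isolated test above, is another option, but it may have no effect if `attend.py` selects its SDPA backends internally.

    # workaround sketch (assumption: fp16 precision is acceptable for this experiment)
    mixture_of_attn = mixture_of_attn.half()
    x = x.half()

    out = mixture_of_attn(x)  # (1, 1023, 512), dtype torch.float16
    print(out.shape)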