During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/mnt/bn/motor-nlp-team/mlx/users/zhangkaiqi.zlkqz/repo/5355/sae/k-sparse_SAE.py", line 154, in <module>
new_hidden_state = sae.decode(latent_act, latent_index)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/sae.py", line 188, in decode
y = decoder_impl(top_indices, top_acts.to(self.dtype), self.W_dec.mT)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/utils.py", line 102, in triton_decode
return TritonDecoder.apply(top_indices, top_acts, W_dec)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/kernels.py", line 406, in forward
return triton_sparse_dense_matmul(
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/kernels.py", line 204, in triton_sparse_dense_matmul
triton_sparse_dense_matmul_kernel[(A,)](
File "<string>", line 41, in triton_sparse_dense_matmul_kernel
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/compiler.py", line 1590, in compile
fn_cache_manager = CacheManager(make_hash(fn, kwargs))
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/compiler.py", line 1500, in make_hash
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 333, in cache_key
dependencies_finder.visit(self.parse())
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 428, in generic_visit
self.visit(value)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 55, in visit_Call
func = self.visit(node.func)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 52, in visit_Attribute
return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'device_assert'
When I call `Sae.decode()`, it raises the following error:
Traceback (most recent call last):
File "<string>", line 21, in triton_sparse_dense_matmul_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.int64, torch.float32, torch.float32, torch.float32, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (64, 4096), (True, True, True, True, (True, False), (False, True), (False, False), (True, False), (True, False), (False, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/mnt/bn/motor-nlp-team/mlx/users/zhangkaiqi.zlkqz/repo/5355/sae/k-sparse_SAE.py", line 154, in <module>
new_hidden_state = sae.decode(latent_act, latent_index)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/sae.py", line 188, in decode
y = decoder_impl(top_indices, top_acts.to(self.dtype), self.W_dec.mT)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/utils.py", line 102, in triton_decode
return TritonDecoder.apply(top_indices, top_acts, W_dec)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/kernels.py", line 406, in forward
return triton_sparse_dense_matmul(
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/sae/kernels.py", line 204, in triton_sparse_dense_matmul
triton_sparse_dense_matmul_kernel[(A,)](
File "<string>", line 41, in triton_sparse_dense_matmul_kernel
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/compiler.py", line 1590, in compile
fn_cache_manager = CacheManager(make_hash(fn, kwargs))
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/compiler.py", line 1500, in make_hash
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 333, in cache_key
dependencies_finder.visit(self.parse())
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 428, in generic_visit
self.visit(value)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 55, in visit_Call
func = self.visit(node.func)
File "/root/anaconda3/envs/py310/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/root/anaconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 52, in visit_Attribute
return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'device_assert'
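This appears to be a Triton kernel error. The final `AttributeError` shows that the kernel `triton_sparse_dense_matmul_kernel` (in `sae/kernels.py`) references `device_assert`, but the installed `triton.language` module does not provide that attribute. This suggests a mismatch between the installed Triton version and the version the `sae` kernels expect.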