OpenNLPLab / lightning-attention

Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
MIT License

TypeError("unhashable type: 'tensor'") #5

Closed · caihaihua057200 closed this issue 4 months ago

caihaihua057200 commented 7 months ago

```
Traceback (most recent call last):
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 279, in generic_visit
    self.visit(item)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 934, in visit_Call
    args = [self.visit(arg) for arg in node.args]
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 934, in <listcomp>
    args = [self.visit(arg) for arg in node.args]
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 462, in visit_BinOp
    lhs = self.visit(node.left)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 650, in visit_UnaryOp
    op = self.visit(node.operand)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
    return visitor(node)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 929, in visit_Call
    static_implementation = self.statically_implemented_functions.get(fn)
TypeError: unhashable type: 'tensor'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "L_attention.py", line 22, in <module>
    o = lightning_attn_func(q, k, v, s)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/lightning_attn/ops/lightning_attn_interface.py", line 34, in lightning_attn_func
    o = lightning_attn2(q, k, v, s)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/lightning_attn/ops/triton/lightning_attn2.py", line 429, in forward
    BLOCK_MODEL=BLOCK_MODEL,
  File "<string>", line 63, in _fwd_kernel
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 39:22:
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
S_block_ptr = S + off_h

##### init diag decay(Lambda); q, k decay; kv
s = tl.load(S_block_ptr)
# q, k decay
off_block = tl.arange(
    0, BLOCK
)  # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
                  ^
TypeError("unhashable type: 'tensor'")
```

Doraemonzzz commented 7 months ago

Hello, can you show the Triton version with the following command:

```
pip list | grep triton
```

By the way, can you show the shape of s?

caihaihua057200 commented 7 months ago

```
triton          2.0.0
triton-nightly  2.1.0.dev20230728172942
```

```python
s = _build_slope_tensor(h).to(q.device).to(torch.float32)
# s.shape is torch.Size([12, 1, 1])
```

I can't print this `s` at the failing line (`q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])`).
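(A host-side print() indeed cannot see values inside a @triton.jit kernel. In recent Triton builds, tl.device_print can; the snippet below is a standalone debugging sketch, not the repo's kernel, and assumes a CUDA device is available.)

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _peek_s(S):
    # one program per head: load that head's scalar slope and print it
    off_h = tl.program_id(0)
    s = tl.load(S + off_h)
    tl.device_print("s", s)  # device-side print; host print() won't work here


slopes = torch.rand(12, 1, 1, device="cuda", dtype=torch.float32)
_peek_s[(12,)](slopes)
torch.cuda.synchronize()  # flush device-side prints before exiting
```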

Doraemonzzz commented 7 months ago

Can you share the hardware you use? The code has only been tested on A100/A800.

XuyangShen commented 7 months ago

By the way, could you please try `torch.cuda.is_available()` and `torch.cuda.get_device_name(0)`? This looks like a weird error.

caihaihua057200 commented 7 months ago

```
>>> torch.cuda.is_available()
True
>>> torch.cuda.get_device_name(0)
'NVIDIA GeForce RTX 3090'
```

Thank you for the reply. I indeed don't have an A100/A800.

Doraemonzzz commented 7 months ago

I think this problem is about shared memory; can you try the solution in the linked issue?
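(For readers hitting this on consumer GPUs: if the fix in that issue is the usual shared-memory workaround, it amounts to shrinking the Triton tile sizes in lightning_attn2.py so each block's working set fits. A sketch of that kind of edit follows; the constant names come from the traceback above, but the values are assumptions, not the repo's actual fix.)

```python
# Hypothetical shared-memory workaround for GPUs such as the RTX 3090,
# which have less shared memory per SM than A100/A800. Smaller tiles
# shrink the per-block q/k/v working set. Names come from the traceback
# above; the values are illustrative assumptions.
BLOCK = 32        # sequence-dimension tile, used in tl.arange(0, BLOCK)
BLOCK_MODEL = 32  # model-dimension tile for the V/O block pointers
```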

Doraemonzzz commented 7 months ago

Hello, has the problem been solved?