triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Unable to run the GEMM example #1553

Open cctry opened 1 year ago

cctry commented 1 year ago

Hi,

I compiled Triton from source and tried to run the GEMM example, but the compiler reports the following error:

    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
        ^
TypeError("unhashable type: 'tensor'")

Removing this type conversion makes the code compile.

I am using Python 3.7, and my Triton checkout is at commit fef8150b653ccbafd741443afb9dcff4b84e7c78 (the latest at the time).

zhaozhixu commented 1 year ago

Same issue here. Strangely, though, the original 03-matrix-multiplication.py compiles on a T4 GPU but not on an A10 or V100. (It's not GPU-specific but Python 3.7-specific; see the comments below.)

Traceback (most recent call last):
  File "<string>", line 22, in matmul_kernel
KeyError: ('2-.-1-.-0-83ca8b715a9dc5f32dc1110973485f64-f0f999f0d4d23228158479f1c3b45a40-e66911055e86008984ad15bfbb8dd5b7-f32ad5a69e3017a5098d3d5532499369-cacecb5a01b695fe1eb376e18972d557-06b47813aaed5d9f42e68c0b9c8e48c0-f9f2894cae3ab046c5bb3dc5c287118f-8723f1b05f5105fe82e1152481e8391b-28cf33ff2045e5ba4563704dbc5615cc-84d6008a7b72daa9acf1b72e2349b605-293a6a48c2635776c9dee8b54603306f-9143a91f622bbab781d8d1b2e64f6dbb-22daff6db48896bd8b64e2af96e88d58', (torch.float16, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 256, 64, 8, ''), (True, True, True, (True, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)), 8, 3, False)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 1019, in ast_to_ttir
    generator.visit(fn.parse())
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 919, in visit
    return super().visit(node)
  File "/data00/zhaozhixu/.local/lib/python3.7/ast.py", line 262, in visit
    return visitor(node)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 216, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/data00/zhaozhixu/.local/lib/python3.7/ast.py", line 270, in generic_visit
    self.visit(item)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 919, in visit
    return super().visit(node)
  File "/data00/zhaozhixu/.local/lib/python3.7/ast.py", line 262, in visit
    return visitor(node)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 285, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 154, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 919, in visit
    return super().visit(node)
  File "/data00/zhaozhixu/.local/lib/python3.7/ast.py", line 262, in visit
    return visitor(node)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 337, in visit_Assign
    values = self.visit(node.value)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 919, in visit
    return super().visit(node)
  File "/data00/zhaozhixu/.local/lib/python3.7/ast.py", line 262, in visit
    return visitor(node)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 837, in visit_Call
    static_implementation = self.statically_implemented_functions.get(fn)
TypeError: unhashable type: 'tensor'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "python/tutorials/03-matrix-multiplication.py", line 300, in <module>
    triton_output = matmul(a, b)
  File "python/tutorials/03-matrix-multiplication.py", line 286, in matmul
    ACTIVATION=activation
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/runtime/autotuner.py", line 98, in run
    for config in pruned_configs}
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/runtime/autotuner.py", line 98, in <dictcomp>
    for config in pruned_configs}
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/runtime/autotuner.py", line 80, in _bench
    return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/testing.py", line 44, in do_bench
    fn()
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/runtime/autotuner.py", line 78, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 42, in matmul_kernel
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/compiler.py", line 465, in compile
    next_module = compile_kernel(module)
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/compiler.py", line 380, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
  File "/data00/zhaozhixu/workspace/source/triton/python/triton/compiler/code_generator.py", line 1028, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 67:8:        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
        ^
TypeError("unhashable type: 'tensor'")

Removing the type conversion c = accumulator.to(tl.float16) also makes the code compile (and the results still match PyTorch on the A10!).

I'm using Python 3.7 on the main branch, at commit https://github.com/openai/triton/commit/0daee68d7196534ff6a3af427d4ce99cd615c255

ptillet commented 1 year ago

This is super weird :o It's possible that it's a Python 3.7-specific issue and that our CI isn't catching it because it runs a more recent Python.

zhaozhixu commented 1 year ago

I can confirm it's a Python 3.7-specific issue: the example compiles successfully after upgrading to Python 3.8.

(My earlier comment that it compiles on the T4 was misleading: that machine was actually running Python 3.8.)
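
A plausible explanation for why only 3.7 is affected (my reading, not something confirmed in this thread): the failing line in the traceback above is `self.statically_implemented_functions.get(fn)`, and for `accumulator.to(...)` the callee `fn` is a bound method of a Triton tensor. Triton tensors are unhashable, as the error message says, presumably because they overload `==`. As far as I know, Python 3.7 computes a bound method's hash from `hash(__self__)`, while Python 3.8 and later hash the identity of `__self__`. That would make the same dict lookup fail only on 3.7, and it would also explain why dropping the `.to(...)` method call makes the error go away. The pure-Python sketch below reproduces the mechanism without Triton. The `tensor` class and `lookup_callee` are hypothetical stand-ins, not Triton code.

```python
import sys


class tensor:
    """Hypothetical stand-in for triton.language.tensor.

    Defining __eq__ without __hash__ makes Python set __hash__ to None,
    so instances are unhashable.
    """

    def __eq__(self, other):
        # Element-wise comparison builds a new tensor instead of returning a bool.
        return tensor()

    def to(self, dtype):
        return self


# Stand-in for CodeGenerator.statically_implemented_functions.
statically_implemented_functions = {print: "handled statically"}


def lookup_callee(fn):
    """Mimics `self.statically_implemented_functions.get(fn)` in visit_Call."""
    return statically_implemented_functions.get(fn)


if __name__ == "__main__":
    acc = tensor()
    print("Python", sys.version.split()[0])
    try:
        hash(acc)
    except TypeError as err:
        print("hash(acc) ->", err)
    try:
        # The callee of `acc.to(...)` is the bound method `acc.to`.
        print("lookup(acc.to) ->", lookup_callee(acc.to))
    except TypeError as err:
        # Expected on Python 3.7, where hashing a bound method hashes __self__.
        print("lookup(acc.to) failed ->", err)
```

On Python 3.8+ the second lookup should print `None`. On 3.7 I would expect it to fail with the same `unhashable type: 'tensor'` message seen in the traceback.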

ptillet commented 1 year ago

Since Python 3.7 will stop getting security updates in June 2023, I would recommend switching to Python 3.8 anyway :-)
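
If you go the upgrade route, a small guard at the top of a script can replace the cryptic "unhashable type: 'tensor'" compile error with an explicit message. This is a minimal sketch. `require_python` and `MIN_PYTHON` are hypothetical names and are not part of Triton. The 3.8 floor only reflects what was found in this thread.

```python
import sys

# Hypothetical minimum: in this thread the GEMM tutorial compiled on 3.8 but not on 3.7.
MIN_PYTHON = (3, 8)


def require_python(minimum=MIN_PYTHON, current=None):
    """Raise RuntimeError if the interpreter (or `current`) is older than `minimum`."""
    if current is None:
        current = sys.version_info
    if tuple(current[:2]) < tuple(minimum):
        found = ".".join(str(part) for part in current[:2])
        wanted = ".".join(str(part) for part in minimum)
        raise RuntimeError(
            f"Python {found} detected; Python >= {wanted} is required "
            "to compile these Triton kernels."
        )


# Fail fast before any @triton.jit kernel is compiled.
require_python()
```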

zhaozhixu commented 1 year ago

Yes, but it would still be more convenient if this could be fixed :) Python 3.7 is the default version shipped with Debian 10, which is still an LTS release and is widely deployed.

chengzeyi commented 1 year ago

I found a workaround that monkey-patches the CodeGenerator class:

import sys
import torch
from torch._prims_common import suggest_memory_format
import triton
import triton.language as tl

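# Compatibility shim: alias tl.reduce to tl.reduction on Triton builds that lack tl.reduce.
if not hasattr(tl, 'reduce'):
    tl.reduce = tl.reduction

if sys.version_info < (3, 8):
    from triton.compiler.code_generator import CodeGenerator

    # On Python 3.7, `statically_implemented_functions.get(fn)` in
    # CodeGenerator.visit_Call raises TypeError when the callee is unhashable
    # (e.g. the bound method `accumulator.to`). Wrap the table so that such a
    # lookup returns None, i.e. "not statically implemented".
    class StaticallyImplementedFunctionsWrapper:

        def __init__(self, functions):
            self.functions = functions

        def get(self, name):
            try:
                return self.functions.get(name)
            except Exception:
                return None

    CodeGenerator.statically_implemented_functions = StaticallyImplementedFunctionsWrapper(
        CodeGenerator.statically_implemented_functions)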

@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = weight_2 * (1. / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,

...
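
The core of the workaround above is a lookup that treats an unhashable callee as "not statically implemented" instead of crashing. Here is the same idea isolated from Triton so that it can be tested on its own. Unlike the original, this variation catches only `TypeError`, which is what `dict.get` raises for an unhashable key, so unrelated bugs are not silently swallowed. `TolerantLookup` is a hypothetical name.

```python
class TolerantLookup:
    """Read-only wrapper whose .get() returns `default` for unhashable keys.

    An unhashable key can never have been inserted into a dict, so returning
    `default` matches what dict.get returns for any absent key.
    """

    def __init__(self, mapping):
        self.mapping = mapping

    def get(self, key, default=None):
        try:
            return self.mapping.get(key, default)
        except TypeError:
            # dict.get raises TypeError when hash(key) fails.
            return default


if __name__ == "__main__":
    table = TolerantLookup({len: "static len"})
    print(table.get(len))        # static len
    print(table.get([1, 2, 3]))  # None: lists are unhashable
```

To use it as the patch, you would assign `CodeGenerator.statically_implemented_functions = TolerantLookup(CodeGenerator.statically_implemented_functions)` under the same `sys.version_info < (3, 8)` guard. This assumes, as the original workaround does, that the code generator only calls `.get()` on that table.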