triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Wrong source read by triton.jit in some python versions #1589

Open orsharir opened 1 year ago

orsharir commented 1 year ago

In some versions of Python (details below), triton.jit reads only part of the decorated function's source. When the function is first called, that source is parsed with ast.parse() and an assertion checks that the body of the resulting ast.Module has length 1; with the truncated source, the assertion fails. The user sees only the failed assert "assert len(tree.body) == 1", with no explanation of what went wrong: https://github.com/openai/triton/blob/7d2a4d95c28ff89c5b9431f003928d2cda46a71f/python/triton/runtime/jit.py#L390 I suggest adding a basic check that reports a more interpretable error message, which could either point to this issue or list possible workarounds.
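A check along these lines might look like the following. This is only a sketch: `extract_def_source` is a hypothetical helper, not an existing Triton function, and the error wording is illustrative.

```python
import ast


def extract_def_source(raw_src: str, fn_name: str) -> str:
    """Hypothetical helper: return raw_src from 'def' onward, or fail readably."""
    start = raw_src.find("def")
    if start == -1:
        # No function definition at all: most likely the getsource bug.
        raise RuntimeError(
            f"Could not find the definition of {fn_name!r} in the text returned "
            "by inspect.getsource. Python versions before 3.10.10 can return "
            "only a decorator when its arguments contain nested parentheses "
            "(see https://github.com/python/cpython/issues/102647). "
            "Workaround: build the decorator arguments before the decorator."
        )
    src = raw_src[start:]
    tree = ast.parse(src)
    if len(tree.body) != 1:
        raise RuntimeError(
            f"Expected exactly one top-level statement for {fn_name!r}, "
            f"found {len(tree.body)}."
        )
    return src
```

With a check like this, the truncated source produces a message that names the function and the likely cause instead of a bare assertion failure.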

Details:

There's an odd bug in inspect.getsource in some versions of Python (notably Python 3.8, which is used in many Docker images, but anything before 3.10.10 is affected) that causes it to return just a decorator of a function instead of the function's complete source. Specifically, this can happen when a decorator contains an additional nested pair of round parentheses inside its arguments (not always; there are exceptions). See https://github.com/python/cpython/issues/102647 for more details.

In triton.jit, inspect.getsource is used to read the source of the function before analyzing it with ast. This Python bug can be triggered, for instance, if you are also using triton.autotune or triton.heuristics and their arguments include such an expression.

Example

For instance, the following triggers it:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_S1': BLOCK_S1, 'BLOCK_D': BLOCK_D},
                      num_stages=num_stages, num_warps=num_wraps)
        for BLOCK_S1 in [16, 32, 64, 128]
        for BLOCK_D in [32, 64, 128]
        for num_stages in [2, 3]
        for num_wraps in [2, 4]
    ],
    key=['B', 'H', 'S', 'D'],
    prune_configs_by={
        'early_config_prune': early_config_prune_base_vs_diff,
        'perf_model': lambda *args, **kwargs: 1.0,
        'top_k': 10
    },
)
@triton.jit
def some_func():
    ...

In this case, it was caused simply by using a list comprehension with triton.Config(). Calling inspect.getsource(some_func) returns the code for just the first autotune decorator, and the actual function definition is missing. triton.jit looks for the beginning of some_func() by calling find() on the returned source with the string 'def', but because of the above bug that string doesn't exist, so the source is truncated to a single newline character. Calling ast.parse on a single newline returns an ast.Module with an empty body, which triggers the assertion error.
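The failure chain can be reproduced without Triton. This is a minimal sketch that assumes jit.py slices the source starting at the index returned by find('def') (inferred from the description here, not checked against the linked commit). `truncated_src` is a shortened stand-in for what the buggy inspect.getsource returns.

```python
import ast

# Stand-in for the buggy inspect.getsource output: decorator text only,
# with no "def" anywhere in it (shortened from the example in this issue).
truncated_src = (
    "@triton.autotune(\n"
    "    configs=[\n"
)

start = truncated_src.find("def")  # -1, because the substring is not found
src = truncated_src[start:]        # slicing with [-1:] keeps only the last character
tree = ast.parse(src)              # parses without error, but the module is empty

print(start, repr(src), len(tree.body))  # -1 '\n' 0
```

Because str.find returns -1 rather than raising, the slice silently keeps only the trailing newline, which is why the first visible symptom is the later assertion on the length of tree.body.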

Workaround

There's an easy workaround for older Python versions: prepare all the arguments to the decorators before using them. However, most users wouldn't know that this is the cause of their problem.
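A sketch of that workaround pattern follows. So that the snippet runs without Triton installed, `Config` and `autotune` are local stand-ins (assumptions for illustration) for triton.Config and triton.autotune; in real code the real decorators would be used, with triton.jit stacked below autotune. The point is only that the decorator call itself contains plain names and no nested parentheses.

```python
from dataclasses import dataclass
from itertools import product


@dataclass
class Config:  # stand-in for triton.Config, for illustration only
    kwargs: dict
    num_stages: int
    num_warps: int


def autotune(configs, key):  # stand-in for triton.autotune, for illustration only
    def decorator(fn):
        fn.configs, fn.key = configs, key
        return fn
    return decorator


# Build everything that needs calls or comprehensions *before* the decorator.
CONFIGS = [
    Config({'BLOCK_S1': block_s1, 'BLOCK_D': block_d},
           num_stages=num_stages, num_warps=num_warps)
    for block_s1, block_d, num_stages, num_warps
    in product([16, 32, 64, 128], [32, 64, 128], [2, 3], [2, 4])
]
KEY = ['B', 'H', 'S', 'D']


# The decorator arguments are now bare names only.
@autotune(configs=CONFIGS, key=KEY)
def some_func():
    ...
```

The same idea applies to triton.heuristics: assign the dictionary of lambdas to a module-level name first, then pass that name to the decorator.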

chengzeyi commented 12 months ago

I am going to decorate with eval😢

@eval('''triton.heuristics({
    'ROW_SIZE':
    lambda kwargs: triton.next_power_of_2(kwargs['C'] // kwargs['groups']),
    'BLOCK_SIZE':
    lambda kwargs: max(
        1, 4096 // (triton.next_power_of_2(kwargs['C'] // kwargs['groups']))),
})''')
@triton.jit
def group_norm_4d_channels_last_forward_collect_stats_kernel(
    input_ptr,
    N,
    C,
    HxW,
    groups,
    eps,
    mean_ptr,
    rstd_ptr,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pass