pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.

How to use the new autotune introduced in https://github.com/pytorch/torchdynamo/pull/1338 #2023

Closed ayoub-louati closed 1 year ago

ayoub-louati commented 1 year ago

Hello, how can we use the new decorator related to caching_autotune, introduced in https://github.com/pytorch/torchdynamo/pull/1338, with a newly defined kernel? Here is an example of the kernel's signature:

import triton
import triton.language as tl


@triton.jit
def kernel_fma(
    C,  # Pointers to matrices
    ACT_INPUTS,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_om,
    stride_on,
    stride_im,
    stride_ik,
    stride_wn,
    stride_wk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUTS: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    ...  # kernel body omitted in the question

I introduced this decorator:

# CachingAutotuner comes from inductor's Triton autotuning module
# (under torch._inductor in recent PyTorch; the exact path has moved between versions).
def autotune(configs, meta, save_cache_hook=False):
    def decorator(fn):
        return CachingAutotuner(
            # force autotune by setting save_cache_hook to False
            fn,
            meta=meta,
            configs=configs,
            save_cache_hook=save_cache_hook,
        )

    return decorator

based on this test example: pytorch/test_torchinductor.py at fae821c2f166fccab6a3c34e293c7268f61e82ba · pytorch/pytorch · GitHub
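
To illustrate, applying the decorator would look roughly like this (a hedged sketch; my_configs and my_meta are placeholder names, and building meta by hand is the part inductor normally does for you):

# my_configs: a list of triton.Config objects to benchmark
# my_meta: hand-written equivalent of the metadata inductor generates
#          (signature, constants, device, ...) -- this is the tricky part
kernel_fma_tuned = autotune(configs=my_configs, meta=my_meta)(kernel_fma)

# Inductor-generated code launches such wrappers through a
# .run(..., grid=..., stream=...) call rather than the usual
# kernel[grid](...) syntax.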

But I thought there might be a better way to use caching_autotune.

Thanks in advance,

jansel commented 1 year ago

Are you trying to use this for handwritten Triton kernels without inductor? If so, why not just use triton.autotune? There is also an AOT compilation option in Triton.
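
For reference, plain triton.autotune on a handwritten kernel looks roughly like this; a minimal sketch on a toy add kernel, with illustrative configs and key values rather than anything taken from kernel_fma:

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune whenever this argument changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    # BLOCK_SIZE is not passed here: the autotuner picks it from the configs above.
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n)
    return out

The same pattern would apply to kernel_fma: stack @triton.autotune above @triton.jit, move the tuned meta-parameters (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, ...) into the configs, and key on CACHE_KEY_M/CACHE_KEY_N/CACHE_KEY_K, which are presumably there for exactly that purpose, so re-tuning only happens when the problem-size bucket changes.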

ayoub-louati commented 1 year ago

@jansel Yes, it is a handwritten Triton kernel without inductor. I'm trying to use this one because, as stated in the PR, it reduces CPU overheads when cudagraphs is disabled, and the cache it introduces is really interesting because it allows compiled kernels to be reused from one run to the next. Is that possible, or is it tied to inductor?

jansel commented 1 year ago

This API is internal to inductor and not intended for handwritten kernels. You may be able to adapt it to your needs, but will need to annotate the Triton signature/invariants/metadata manually and will have no backward compatibility guarantees.

Inductor generates the needed metadata here: https://github.com/pytorch/pytorch/blob/d41b5d7c145f3e09c7223c2b707933266241ec9b/torch/_inductor/codegen/triton.py#L1063, which relies on some compiler analysis.
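
To give a rough idea of what would have to be written by hand, the metadata is along these lines (illustrative only; the exact keys and formats change between versions, so the linked codegen is the authoritative reference):

meta = {
    # argument index -> Triton type string, describing the kernel signature
    "signature": {0: "*fp32", 1: "*fp32", 2: "*fp32", 3: "i32"},
    # device index the kernel is compiled for
    "device": 0,
    # constexpr arguments fixed at compile time rather than autotuned
    "constants": {},
}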