triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[RFC] "autotuner deja-vu" save and restore autotuner cache persistently #4020

Open bringlein opened 3 months ago

bringlein commented 3 months ago

Motivation

In our experiments and applications, the triton autotuner is key to achieving competitive or even the best performance (e.g. for flash attention in vLLM). We also learned that for more complex kernels, the autotuner needs to choose from more options.

On the other hand, providing a rich selection to the autotuner increases the overhead and pushes the latency into unacceptable ranges. That's why some applications add extra warmup steps (e.g. here). Nonetheless, these warmup algorithms are not ideal in many scenarios: they may need to be quite extensive, or they slow down the scale-up of an application.

However, for many applications, especially in the domain of ML inference, the decision of the autotuner could be determined ahead of time. For example, the block sizes of layernorm or attention kernels depend on the architecture of the LLM and the number of input tokens. Both parameters are known in advance, or at least their ranges are.
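Because the values that make up the autotuner key are known ahead of time, the set of keys that need tuning can be enumerated up front, e.g. in a CI job. A minimal sketch, assuming a kernel autotuned on two hypothetical integer key arguments, `head_dim` (fixed by the model architecture) and `num_tokens` (assumed here to be rounded up to a power of two by the application before the kernel call); the names and ranges are illustrative and not taken from vLLM or triton-dejavu:

```python
import itertools

# Hypothetical key parameters: head_dim comes from the model architecture,
# num_tokens from a known range, assumed bucketed to powers of two (1 .. 4096).
HEAD_DIMS = [64, 128]
NUM_TOKENS = [2 ** i for i in range(0, 13)]


def keys_to_pretune(head_dims, num_tokens):
    """Return every (head_dim, num_tokens) key the application can hit."""
    return list(itertools.product(head_dims, num_tokens))


if __name__ == "__main__":
    keys = keys_to_pretune(HEAD_DIMS, NUM_TOKENS)
    print(f"{len(keys)} keys to autotune ahead of time")  # 2 * 13 = 26
```

Without such bucketing, every distinct `num_tokens` value would be a separate autotuner key, so the number of keys to pre-tune grows with the size of the range rather than its logarithm.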

These observations led us to the conclusion that the core problem of the autotuner overhead could be addressed by saving the autotuner cache produced during benchmarking runs or in a CI/CD pipeline, and restoring it later.

Therefore, we implemented our own version of the autotuner, called triton-dejavu. This modified autotuner acts like the original autotuner, but it saves and restores already known cache states. If the autotuner is triggered and the restored cache does not contain the required key, an autotune run is executed, exactly as the original autotuner does.
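The restore-or-autotune behavior described above can be sketched as a small persistent cache. This is a hypothetical illustration, not the triton-dejavu API: `PersistentTuningCache`, `best_config`, and the `autotune` callback are names invented here, and the callback stands in for a real benchmarking run. Keys are stored with `str(key)`, which matches the tuple-string format of the example cache further down.

```python
import json
import os
import tempfile
from typing import Callable, Dict


class PersistentTuningCache:
    """Hypothetical restore-or-autotune cache (not the triton-dejavu API).

    Maps a stringified autotuner key to the best config found for it and
    persists the mapping as JSON so later processes can skip benchmarking.
    """

    def __init__(self, path: str):
        self.path = path
        self.cache: Dict[str, str] = {}
        if os.path.exists(path):
            with open(path) as f:
                self.cache = json.load(f)  # restore previously stored results

    def best_config(self, key: tuple, autotune: Callable[[tuple], str]) -> str:
        skey = str(key)
        if skey not in self.cache:
            # Cache miss: fall back to a regular autotune run, then persist it.
            self.cache[skey] = autotune(key)
            self.save()
        return self.cache[skey]

    def save(self) -> None:
        # Write to a temp file first, then rename, so readers never see a
        # half-written cache file.
        tmp = self.path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(self.cache, f, indent=2)
        os.replace(tmp, self.path)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "cache.json")

    def fake_autotune(key):
        print(f"benchmarking for {key}")
        return "BLOCK_M: 64, BLOCK_N: 64, num_warps: 8"

    PersistentTuningCache(path).best_config((2, 2, True), fake_autotune)  # benchmarks
    PersistentTuningCache(path).best_config((2, 2, True), fake_autotune)  # restored
```

In a CI/CD setting, the JSON file produced by the first process would be shipped alongside the application, so production processes only hit the restore path.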

To determine if a previously stored cache is still applicable, we use the combination of multiple values:

So far, we think that the above-listed combination unambiguously determines whether a cache is applicable. Hence, if all these values match a stored cache, this cache is restored and reused.
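The example storage identifier below encodes the dejavu, CUDA, torch, and triton versions and the GPU model as path components, so a mismatch in any of them points to a different location and no stale cache is restored. A minimal sketch of composing such an identifier, assuming exactly those five components; `storage_identifier` and `current_storage_identifier` are hypothetical helpers, not the triton-dejavu API:

```python
def storage_identifier(dejavu_version, cuda_version, torch_version,
                       triton_version, gpu_name):
    """Build a path-like identifier; any differing component yields a new path."""
    gpu = gpu_name.replace(" ", "_")  # "NVIDIA A100 80GB PCIe" -> "NVIDIA_A100_80GB_PCIe"
    return (f"dejavu_{dejavu_version}/cuda_{cuda_version}/torch_{torch_version}"
            f"/triton_{triton_version}/gpu_{gpu}")


def current_storage_identifier(dejavu_version="0.1"):
    """Collect the components from the running environment (needs torch, triton, a GPU)."""
    import torch
    import triton

    return storage_identifier(
        dejavu_version,
        torch.version.cuda,           # e.g. "12.1"
        torch.__version__,            # e.g. "2.1.2+cu121"
        triton.__version__,           # e.g. "3.0.0"
        torch.cuda.get_device_name(), # name of the current CUDA device
    )
```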

Below is a simple example of what such a stored cache looks like (the "some_function" in the identifier is only there to help us humans analyze what's in the cache):

DejavuStorage identifier: dejavu_0.1/cuda_12.1/torch_2.1.2+cu121/triton_3.0.0/gpu_NVIDIA_A100_80GB_PCIe 
Cache identifier: some_function-6fa5b719e15853974d99a8e237d162900916ca7ff8de80f4e8605bae5d7e7260490-6bd5c9ed739ae8203cfbfab7076483edfd180f614099956bbe45c05ad446134b-ab4979d1539b394f48fe313d5462dc9254ae1623050232bd8d11077553c70c0c
Stored cache: 
{
        "signature": "JITFunction(some_function)",
        "total_bench_time_s": 23.483010053634644,
        "evaluated_configs": 16,
        "cache": {
            "(2, 2, True, 0.0, 0, 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1",
            "(32, 32, True, 0.0, 0, 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1"
        }
}
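To reuse an entry like the ones above, the stringified key has to be turned back into a tuple and the config string split into kernel meta-parameters and launch options. A minimal sketch, assuming every config value is an integer and that `num_warps`, `num_ctas`, and `num_stages` are the only launch options; the helper names are hypothetical:

```python
import ast

# Assumed set of launch options; everything else is a kernel meta-parameter.
LAUNCH_OPTIONS = {"num_warps", "num_ctas", "num_stages"}


def parse_key(key_str):
    """Turn "(2, 2, True, 0.0, 0, 'torch.float16', ...)" back into a tuple."""
    return ast.literal_eval(key_str)  # safe: only evaluates Python literals


def parse_config(config_str):
    """Split 'BLOCK_M: 64, ..., num_stages: 1' into (meta-params, launch options)."""
    kwargs, options = {}, {}
    for item in config_str.split(","):
        name, value = (s.strip() for s in item.split(":", 1))
        (options if name in LAUNCH_OPTIONS else kwargs)[name] = int(value)
    return kwargs, options


def load_entries(stored):
    """Map each restored key tuple to its parsed (kwargs, options) pair."""
    return {parse_key(k): parse_config(v) for k, v in stored["cache"].items()}
```

The two dictionaries could then be turned back into a `triton.Config(kwargs, **options)`; recent triton versions accept `num_warps`, `num_stages`, and `num_ctas` as keyword arguments there, but the exact signature should be checked against the installed version.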

Our experiments show that our dejavu autotuner removes the additional overhead of many autotuner configurations while still profiting from the added flexibility and the increased performance.

Questions

Therefore, we would like your comments on such a "dejavu" autotuner mechanism. We (@tdoublep and I) think that the reasoning behind using a "pre-compiled" autotuner cache for an application applies not only to the applications we currently consider (LLM inference services).

Hence, do you think this would be of interest for merging into the triton project, or should it continue as a standalone project (which we plan to open source)? Additionally, do you think the above-listed values are sufficient to decide whether previous autotuner results can be reused safely?

Thanks for your feedback!

CC: @jlebar @ThomasRaoux @ptillet @Jokeren

leademeule commented 3 months ago

While I am not part of the Triton project myself, I would be very happy to see these features integrated into Triton! Even for the sake of quickly iterating and debugging, this feature would cut out a lot of annoying overhead.

ThomasRaoux commented 3 months ago

Thanks for sharing, this is very interesting. Caching the tuning results is quite important and can be hard to do.

First answering the easy question:

Additionally, do you think the above-listed values are sufficient to decide whether previous autotuner results can be reused safely?

I'm not sure whether the options are hardcoded in your example. There may be more, and they might be backend specific. Note that you may be able to rely on the existing cache_hook, where you can get the "specialization_data" associated with a kernel, which contains a serialization of all parameters needed to recompile the given specialization of the kernel:

https://github.com/openai/triton/blob/main/python/triton/runtime/jit.py#L349

About the main question, whether it should be upstreamed: it is hard to tell without seeing the code. If you could upload this to a fork somewhere, it would be easier to judge whether it makes sense. My guess at this point is that the caching solution is probably too opinionated (meaning different users would want to do it differently) to be upstreamed. There are other ways to avoid compiling and profiling on the fly, such as using warmup or preload and building infrastructure on top of triton.

I wonder whether it would be better to have this as a separate tool that lives on top of triton (again, it's hard for me to tell without seeing how many dependencies there are).

I'll also chat more with other core maintainers to gather more ideas about that.

bringlein commented 3 months ago

Thank you both for your feedback!

@ThomasRaoux :

it is hard to tell without seeing the code

Yes, of course. We are currently working through our internal processes to get this open-sourced in the coming weeks, but we wanted to share the idea and get your feedback early. Based on this feedback, would it work for you if we first make it available as a standalone tool "above" triton and then iterate from there?

There may be more, and they might be backend specific. Note that you may be able to rely on the existing cache_hook, where you can get the "specialization_data" associated with a kernel, which contains a serialization of all parameters needed to recompile the given specialization of the kernel

Thanks for the pointer, I'll look into this!

ThomasRaoux commented 3 months ago

Yes, of course. We are currently working through our internal processes to get this open-sourced in the coming weeks, but we wanted to share the idea and get your feedback early. Based on this feedback, would it work for you if we first make it available as a standalone tool "above" triton and then iterate from there?

Makes sense to me.

bringlein commented 4 weeks ago

We were able to open source our implementation of a dejavu mechanism for triton earlier this week: https://github.com/IBM/triton-dejavu ( :tada: )

The implementation is very similar to the one discussed above; we just additionally include some autotune parameters in the hash. In addition, the triton-dejavu framework offers a way to specify search spaces more efficiently.
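Including autotune parameters in the hash means that changing the search space (for example, adding a config) produces a different cache identifier, so results tuned against the old search space are not restored. A minimal sketch of such a per-function identifier, loosely modeled on the "name-hash-hash-hash" shape of the example cache identifier; which components triton-dejavu actually hashes, and with which hash function, is not stated in this thread, so the choice of function source, key names, and config list, and the use of SHA-256, are assumptions:

```python
import hashlib
import json


def _sha256(obj) -> str:
    # Canonical JSON (sorted keys) so equal inputs always produce the same hash.
    data = json.dumps(obj, sort_keys=True, default=str).encode()
    return hashlib.sha256(data).hexdigest()


def cache_identifier(fn_name, fn_source, autotune_keys, configs):
    """Hypothetical: readable function name followed by one hash per component."""
    return "-".join([
        fn_name,                 # only there to help humans read the identifier
        _sha256(fn_source),      # kernel source: edits invalidate the cache
        _sha256(autotune_keys),  # names of the arguments in the autotuner key
        _sha256(configs),        # candidate configs, i.e. the search space
    ])
```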

Also, as discussed, I looked into the existing cache_hook and the associated specialization_data, but it appeared to me that most of this data is derived from an already chosen configuration, which is too specific for deciding whether to restore autotuner results.

Any feedback would be appreciated!

(And yes...this internal process took far longer than expected...)

bringlein commented 2 days ago

@ThomasRaoux did you have a chance to look at it?

ThomasRaoux commented 2 days ago

@ThomasRaoux did you have a chance to look at it?

Ah, no, I had missed the last message. I'll try to take a look in the next few days.