triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Proposal for device-agnostic autotuner #4417

Open int3 opened 1 month ago

int3 commented 1 month ago

The autotuner is currently tied closely to the CUDA backend: it takes a bunch of CUDA-specific parameters and then passes them to do_bench or do_bench_cudagraph, which both call into many CUDA-specific functions.

I think a cleaner API would be to have @autotune take do_bench as a parameter, and have each backend implement their own backend-specific do_bench. This would also remove the need for the use_cuda_graph parameter that autotune currently has.

We can actually go one step further to simplify the common case: most backends will have just one preferred implementation of do_bench. So we can have each backend register / declare a default do_bench implementation, then have the autotuner inspect the device of its input tensor arguments and pick the appropriate implementation.
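A minimal sketch of the register-then-dispatch idea, in plain Python. `register_do_bench`, `select_do_bench`, and the toy `cuda_do_bench`/`xpu_do_bench` are hypothetical names, not Triton APIs. The device lookup only assumes that tensor arguments expose `device.type`, as `torch.Tensor` does.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict

# Hypothetical registry (not a Triton API): device type -> default do_bench.
_DEFAULT_DO_BENCH: Dict[str, Callable[..., float]] = {}


def register_do_bench(device_type: str):
    """Decorator a backend could use to declare its preferred do_bench."""
    def decorator(fn: Callable[..., float]) -> Callable[..., float]:
        _DEFAULT_DO_BENCH[device_type] = fn
        return fn
    return decorator


def select_do_bench(*args: Any) -> Callable[..., float]:
    """Return the do_bench registered for the first argument that has a device."""
    for arg in args:
        device = getattr(arg, "device", None)
        if device is None:
            continue  # scalars, None, etc. carry no device information
        device_type = getattr(device, "type", str(device))
        if device_type not in _DEFAULT_DO_BENCH:
            raise RuntimeError(f"no do_bench registered for {device_type!r}")
        return _DEFAULT_DO_BENCH[device_type]
    raise RuntimeError("no tensor arguments to infer a device from")


# Toy backends; a real one would time fn with device-specific events.
@register_do_bench("cuda")
def cuda_do_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    fn()
    return 1.0


@register_do_bench("xpu")
def xpu_do_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    fn()
    return 2.0


# Stand-in for a tensor: only the .device.type attribute is used.
x = SimpleNamespace(device=SimpleNamespace(type="xpu"))
bench = select_do_bench(3, x)  # 3 is skipped; x selects the "xpu" entry
```

An explicit `do_bench=` argument to `@autotune` would still take precedence. The registry only supplies the default when none is given.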

Since this will require a decent bit of refactoring, I wanted to put this up for discussion first. I'm also wondering what the approach is for breaking changes to user-facing APIs -- can I remove the warmup, rep, and use_cuda_graph parameters now, or should we deprecate them and remove them at a later date?

xuzhao9 commented 1 month ago

cc @nmacchioni: maybe we should consolidate our effort with the Inductor benchmarker (https://github.com/pytorch/pytorch/pull/130926) and avoid fragmenting the benchmarking scripts across Inductor and Triton autotuning.

nmacchioni commented 1 month ago

I would disagree if removing use_cuda_graph meant always defaulting to using cudagraphs for benchmarking on the Nvidia backend. Benchmarking using cudagraphs is very slow, and in my opinion only useful in specific cases (i.e. getting more accurate perf numbers for non-autotuning cases).

int3 commented 1 month ago

I have no intention of changing the current behavior wrt cudagraphs; it's orthogonal to making the autotuner device-agnostic.

nmacchioni commented 1 month ago

Another question then, why remove warmup and rep?

int3 commented 1 month ago

They are do_bench-specific parameters, so if we are parameterizing autotune over do_bench, we can just pass them directly to do_bench:

```python
# Proposed: warmup/rep move out of autotune and into the do_bench callable
@autotune(
    ...,
    do_bench=lambda fn: triton.testing.do_bench(fn, warmup=25, rep=100)
)
...
```

The default values for the current do_bench implementation are the same as the default values for the current autotune implementation, so I expect that in most cases we can continue to omit all these parameters.
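The sketch below assumes an autotuner that receives `do_bench` as a plain callable. That callable takes a zero-argument `fn` and returns a time in ms, and the autotuner falls back to a default with the same 25/100 ms budgets. `simple_do_bench` and `pick_best_config` are hypothetical stand-ins written for illustration, not Triton code. The timer is `time.perf_counter` with no device synchronization, so it is only meaningful for CPU work.

```python
import functools
import statistics
import time
from typing import Any, Callable, List

DoBench = Callable[[Callable[[], Any]], float]


def simple_do_bench(fn: Callable[[], Any], warmup: float = 25, rep: float = 100) -> float:
    """Hypothetical CPU-only do_bench: median wall-clock time of fn() in ms.

    warmup and rep are time budgets in ms; no device synchronization is done.
    """
    deadline = time.perf_counter() + warmup / 1000
    while time.perf_counter() < deadline:
        fn()
    times: List[float] = []
    deadline = time.perf_counter() + rep / 1000
    while not times or time.perf_counter() < deadline:  # always take >= 1 sample
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1000)
    return statistics.median(times)


def pick_best_config(configs: List[dict], run: Callable[[dict], Any],
                     do_bench: DoBench = simple_do_bench) -> dict:
    """Benchmark run(config) for every config with do_bench; return the fastest."""
    return min(configs, key=lambda cfg: do_bench(functools.partial(run, cfg)))


# Toy "kernel": larger BLOCK means fewer loop iterations.
def run(cfg: dict) -> None:
    sum(range(10_000 // cfg["BLOCK"]))


# Overriding the budgets, mirroring the lambda in the proposal above.
best = pick_best_config(
    [{"BLOCK": 1}, {"BLOCK": 64}],
    run,
    do_bench=lambda fn: simple_do_bench(fn, warmup=1, rep=5),
)
```

Callers that omit `do_bench` get the default budgets. That matches the point that most users could keep omitting these parameters.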

int3 commented 1 month ago

Bump on this -- I'm particularly interested in whether there are objections to outright removing the warmup and rep parameters, or if we need a gradual deprecation process. @ptillet @Jokeren @ThomasRaoux

AnthonyBarbier commented 1 month ago

> Bump on this -- I'm particularly interested in whether there are objections to outright removing the warmup and rep parameters, or if we need a gradual deprecation process. @ptillet @Jokeren @ThomasRaoux

IMO you shouldn't just remove them: do_bench is part of the public API, and in the PyTorch codebase alone there are several hits (including one in the code generated by the Triton inductor backend).

int3 commented 1 month ago

ISTR Triton having quite a few breaking changes in the past, though. E.g. the relocation of the libdevice code (and more I can't recall). PyTorch pins the version of Triton it uses and updates it accordingly; I would hope that other Triton consumers do the same?

AnthonyBarbier commented 1 month ago

> ISTR Triton having quite a few breaking changes in the past, though. E.g. the relocation of the libdevice code (and more I can't recall). PyTorch pins the version of Triton it uses and updates it accordingly; I would hope that other Triton consumers do the same?

I can't believe I need to say it, but the fact that you've made the HW vendors' lives miserable on multiple occasions in the past should push you to improve your process, not make a habit of it...

I'm guessing most HW vendors are organised the same way we are: one team working on Triton and a separate team working on PyTorch integration / eager mode support. Realistically, as there is no decoupling between Triton core and its "backends", the people working on the Triton backend will be keen on keeping their fork as close to Triton's HEAD as possible if they want to stand a chance of upstreaming their backend in the future. In parallel to that, the team in charge of the PyTorch integration (sadly, it's mine...) will have to deal with whatever version of Triton the first team picked to get all the PyTorch tests to run.

I mean, in practice we're stuck:

Until everybody gets a chance to catch up and upstream their work, it would be good to at least try to keep the API stable (and device-agnostic, if you genuinely want backends other than CUDA to work).

int3 commented 1 month ago

> I can't believe I need to say it, but the fact that you've made the HW vendors' lives miserable on multiple occasions in the past should push you to improve your process, not make a habit of it...

There's no "you" here; this is my first API-breaking proposal for Triton, and I'm trying to understand what the general policy / approach is towards API stability. I'd previously worked more on PyTorch and dealt with the friction of Triton API changes there too.

> In parallel to that, the team in charge of the PyTorch integration (sadly, it's mine...) will have to deal with whatever version of Triton the first team picked to get all the PyTorch tests to run.

Thanks, this is the perspective I was lacking. At Meta, whoever changes the Triton version would also have to get the PyTorch tests to run. Our Triton backend development is also decoupled from the Triton version that PyTorch uses. Now that I understand the broader impact, I'm happy to make this change backwards-compatible. Thanks for the input!
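One possible backwards-compatible shape for this, sketched under assumptions. The old `warmup`/`rep` arguments are still accepted, folded into the default `do_bench` with `functools.partial`, and flagged with a `DeprecationWarning`. `resolve_do_bench` and `default_do_bench` are hypothetical helpers, not existing Triton functions. Rejecting `warmup`/`rep` when an explicit `do_bench` is also given is just one design choice.

```python
import functools
import warnings
from typing import Any, Callable, Optional

DoBench = Callable[[Callable[[], Any]], float]


def default_do_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    """Placeholder for a backend's default do_bench; returns a dummy time."""
    fn()
    return 0.0


def resolve_do_bench(do_bench: Optional[DoBench] = None,
                     warmup: Optional[int] = None,
                     rep: Optional[int] = None) -> DoBench:
    """Map old-style autotune(warmup=..., rep=...) arguments onto a do_bench callable."""
    if warmup is None and rep is None:
        return do_bench if do_bench is not None else default_do_bench
    if do_bench is not None:
        # Ambiguous: a custom do_bench owns its own budgets.
        raise ValueError("pass warmup/rep to do_bench itself, not alongside it")
    warnings.warn(
        "autotune(warmup=..., rep=...) is deprecated; pass do_bench=... instead",
        DeprecationWarning,
        stacklevel=2,
    )
    # Only forward the legacy values the caller actually set.
    overrides = {k: v for k, v in (("warmup", warmup), ("rep", rep)) if v is not None}
    return functools.partial(default_do_bench, **overrides)
```

Existing call sites keep working and only see a warning, which leaves room to remove the legacy parameters in a later release.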

AnthonyBarbier commented 1 month ago

> There's no "you" here; this is my first API-breaking proposal for Triton,

Sorry, I meant "Triton" in this context, not you personally.

I think the tricky part with the PyTorch Triton pin (as with some of PyTorch's other third-party dependencies) is that it affects a (growing) number of different device types, so every update is challenging: if it breaks or causes a regression somewhere, it will be reverted (oneDNN, for example, has had similar issues in the past). As a result, the pin doesn't get updated very often.

Maybe we should try to keep a rolling "triton-next" branch in PyTorch which would track Triton's head so that at least we can share the burden when things break?

nmacchioni commented 1 month ago

> I think the tricky part with the PyTorch Triton pin (as with some of PyTorch's other third-party dependencies) is that it affects a (growing) number of different device types, so every update is challenging: if it breaks or causes a regression somewhere, it will be reverted (oneDNN, for example, has had similar issues in the past). As a result, the pin doesn't get updated very often.

A lot of folks (everyone, haha) on our end have experienced difficulties in this regard. There is simply no easy way to do a pin update (this is not anybody's fault; it is just the nature of the problem). Some issues are simple to diagnose and fix (e.g. straight breakages), but these are very limited in number. The real crux of the problem is the many subtle issues that can reveal themselves when Triton is really stress tested, as it is when used in a library such as PyTorch. These subtle issues can cause performance degradation, numerical accuracy issues, and other failures that are very time-consuming to root cause and fix. That is why we lag behind on our pin updates.

> Maybe we should try to keep a rolling "triton-next" branch in PyTorch which would track Triton's head so that at least we can share the burden when things break?

This would probably help catch the simpler issues, which is nice, but I think the PyTorch OSS tests simply do not cover enough ground to catch the more niche breakages. So I'd imagine that even with up-to-date tracking of Triton's head, we'd still end up with a sizable number of issues to fix at pin update time.