triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Detecting triton version reliably in xformers #2161

Status: Open · danthe3rd opened 1 year ago

danthe3rd commented 1 year ago

Hello,

I'm a maintainer of the xFormers library, which contains many Triton kernels that accelerate workloads (flash-attention, fused operations, ...). One issue we face is that the same Triton code may or may not work depending on which Triton version is installed, and it's very hard to know in advance whether the installed Triton includes a given bugfix.

Something that would help us tremendously is an auto-incrementing version number on every release (e.g. triton.__version__ = "2.1.0.devN", where N is, for instance, the number of commits on main). Currently, triton.__version__ is "2.1.0" for many Triton builds that have slightly different APIs and bugfixes.
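To illustrate why a ".devN" suffix would help downstream libraries, here is a minimal sketch of the check such a scheme would enable. It assumes the proposed "X.Y.Z" / "X.Y.Z.devN" format (a small subset of PEP 440), which Triton does not currently emit; the helpers version_key and at_least are hypothetical, and the commit number 1234 is made up for the example. In practice the installed string would come from triton.__version__, and packaging.version.Version covers full PEP 440 if that library is available.

```python
import re
from typing import Tuple

# Hypothetical helper: only handles "X.Y.Z" and "X.Y.Z.devN",
# a small subset of PEP 440 version strings.
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:\.dev(\d+))?$")


def version_key(version: str) -> Tuple[int, int, int, int, int]:
    """Turn a version string into a tuple that sorts in release order."""
    m = _VERSION_RE.match(version)
    if m is None:
        raise ValueError(f"unsupported version string: {version!r}")
    major, minor, patch, dev = m.groups()
    # Under PEP 440 a .devN build sorts *before* the final release,
    # so the final release gets a 1 in the fourth slot and dev builds get 0.
    if dev is None:
        return (int(major), int(minor), int(patch), 1, 0)
    return (int(major), int(minor), int(patch), 0, int(dev))


def at_least(installed: str, required: str) -> bool:
    """Return True if `installed` is the same as or newer than `required`."""
    return version_key(installed) >= version_key(required)


# Hypothetical: a kernel needs a fix that landed as commit number 1234 on main.
print(at_least("2.1.0.dev1300", "2.1.0.dev1234"))  # True
print(at_least("2.1.0.dev1200", "2.1.0.dev1234"))  # False
print(at_least("2.1.0", "2.1.0.dev1234"))          # True
```

With a single "2.1.0" shared by many builds, the same comparison cannot distinguish a build that has the fix from one that doesn't, which is the problem described above.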

Possible workarounds:

(1) We've considered using importlib.metadata to fetch the development version installed through pip (which contains the nightly date). This is a bit convoluted because Triton can be installed either as the triton or the triton-nightly package, but it could work. However, this approach breaks down for packages installed with conda: the torchtriton package from pytorch-nightly, for instance, only provides a commit hash. It also might not work for builds from source.

(2) Another approach would be to detect individual features by compiling toy kernels and comparing the generated PTX. However, this is very hard to maintain and would require some engineering for every operator we have in Triton. Furthermore, PTX has no stability guarantee.
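Workaround (1) can be sketched with the standard library's importlib.metadata. This is a minimal sketch under loud assumptions: the candidate distribution names are taken from this issue, but whether a conda-installed torchtriton exposes Python metadata under that exact name is not confirmed here, and the lookup order is an arbitrary choice. If several candidates are installed, this sketch cannot tell which one is actually providing the imported triton module. The helper installed_triton_distribution is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Iterable, Optional, Tuple

# Assumption: distribution names a Triton build may be installed under,
# taken from this issue. The order is arbitrary; the first match is returned.
CANDIDATE_DISTRIBUTIONS = ("triton-nightly", "triton", "torchtriton")


def installed_triton_distribution(
    candidates: Iterable[str] = CANDIDATE_DISTRIBUTIONS,
) -> Optional[Tuple[str, str]]:
    """Return (distribution name, version string) for the first candidate
    that has installed package metadata, or None if none does."""
    for name in candidates:
        try:
            return name, version(name)
        except PackageNotFoundError:
            # Not installed under this name (or no metadata, e.g. some source builds).
            continue
    return None


info = installed_triton_distribution()
if info is None:
    print("no Triton distribution metadata found")
else:
    name, ver = info
    print(f"{name}: {ver}")
```

This also reproduces the limitation described above: when the version string carries only a commit hash, it can be recorded but not ordered against other builds without access to the git history.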

Do you have any suggestions on how we could detect the triton version reliably? Thanks a lot!

cc @lw @fmassa @plabatut @bottler

CHDev93 commented 1 year ago

I'd like to +1 this, as I have run into a similar problem.