SomeoneSerge opened this issue 1 year ago
Extra: a configurable path to `ptxas` is desirable. Maybe even just accepting `ptxas` from `$PATH`?
Update: I also notice that pytorch seems to have vendored `ptxas` at one point, and stopped doing so once triton became its dependency. This is probably good: if pytorch just asks some function from triton to report the location of `ptxas`, then by making `ptxas` optional and configurable in triton, we'd have addressed the issue in pytorch as well.
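A minimal sketch of what such a lookup could look like, assuming a hypothetical `TRITON_PTXAS_PATH` override variable (the name is an assumption for illustration, not an existing triton API):

```python
import os
import shutil

def find_ptxas():
    # Hypothetical override variable; the name is an assumption.
    override = os.environ.get("TRITON_PTXAS_PATH")
    if override:
        return override
    # Fall back to whatever ptxas is on $PATH.
    ptxas = shutil.which("ptxas")
    if ptxas is None:
        raise RuntimeError(
            "ptxas not found; install a CUDA toolkit or set TRITON_PTXAS_PATH")
    return ptxas
```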
The ptxas discussion can be moved to https://github.com/openai/triton/issues/1618
FWIW, I ran into a related issue when using triton (2.2.0) in a conda environment. The CUDA toolkit is installed in the conda env (rather than system-wide), so the compiler can't find the CUDA library on its default linker search path. Below is the stack trace:
```
/usr/bin/ld: skipping incompatible /usr/lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: skipping incompatible /usr/lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
Traceback (most recent call last):
  File "/home/jacobz/Downloads/01-vector-add.py", line 87, in <module>
    output_triton = add(x, y)
  File "/home/jacobz/Downloads/01-vector-add.py", line 73, in add
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
  File "/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
  File "/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/site-packages/triton/compiler/compiler.py", line 614, in compile
    so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
  File "/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/site-packages/triton/compiler/make_launcher.py", line 37, in make_stub
    so = _build(name, src_path, tmpdir)
  File "/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/site-packages/triton/common/build.py", line 107, in _build
    ret = subprocess.check_call(cc_cmd)
  File "/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/subprocess.py", line 369, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpvdufqg5s/main.c', '-O3', '-I/home/jacobz/.conda/envs/lmdeploy-build/lib/python3.10/site-packages/triton/common/../third_party/cuda/include', '-I/home/jacobz/.conda/envs/lmdeploy-build/include/python3.10', '-I/tmp/tmpvdufqg5s', '-shared', '-fPIC', '-L/home/jacobz/.conda/envs/lmdeploy-build/targets/x86_64-linux/lib/stubs', '-lcuda', '-o', '/tmp/tmpvdufqg5s/add_kernel.cpython-310-x86_64-linux-gnu.so', '-L/usr/lib/x86_64-linux-gnu', '-L/usr/lib/i386-linux-gnu', '-L/usr/lib/i386-linux-gnu']' returned non-zero exit status 1.
```
I need to pass the CUDA library path to the compiler. The fix is to add the CUDA stub library path (for me it's `<conda env root>/lib/stubs`) to the compiler flags here:
https://github.com/triton-lang/triton/blob/c9ab44888ed445acf7acb7d377aae98e07630015/python/triton/common/build.py#L89
The path can be found automatically with:

```python
import os

def conda_cuda_dir():
    # CONDA_PREFIX points at the root of the currently active conda environment
    conda_path = os.environ["CONDA_PREFIX"]
    return os.path.join(conda_path, "lib", "stubs")
```
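For illustration, the resulting directory would then be appended as an extra `-L` search path when triton builds the launcher stub, along the lines of this sketch (not the actual `build.py` code; the command is abbreviated from the failing invocation above):

```python
# Sketch: extend the compiler command with the conda stub directory so the
# linker can resolve -lcuda.
cc_cmd = ["/usr/bin/gcc", "main.c", "-O3", "-shared", "-fPIC", "-lcuda",
          "-o", "add_kernel.so"]
cc_cmd += [f"-L{conda_cuda_dir()}"]
```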
This specific issue is fixed on the main branch, where there is an environment variable `TRITON_LIBCUDA_PATH` for this.
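For example, one way to set it from Python before triton compiles anything (the stub path shown is just the conda layout from above):

```python
import os

# Point triton at the conda-provided CUDA driver stubs before any kernel is
# compiled; triton's build step on main reads TRITON_LIBCUDA_PATH.
os.environ["TRITON_LIBCUDA_PATH"] = os.path.join(
    os.environ["CONDA_PREFIX"], "lib", "stubs")

import triton  # import after setting the variable so the build picks it up
```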
Hi! I see that openai/triton requires a working toolchain at run-time, including CUDAToolkit and libpython installations for the host platform. Currently, triton attempts to guess the correct compiler flags on its own: https://github.com/openai/triton/blob/deb2c71fb4f912a5298003fa3fc789885b726607/python/triton/common/build.py#L77-L82

This includes inferring the library locations: https://github.com/openai/triton/blob/deb2c71fb4f912a5298003fa3fc789885b726607/python/triton/common/build.py#L19-L22
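Roughly, that inference amounts to something like the following (a paraphrase of the linked lines, not the exact code):

```python
import os
import subprocess

def libcuda_dirs():
    # Ask `whereis` where libcuda.so lives; the first token of the output
    # is "libcuda.so:", the remaining tokens are candidate paths.
    locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().split()[1:]
    return [os.path.dirname(loc) for loc in locs]
```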
What this means, in practice, is that openai/triton is taking on a job usually performed by tools like CMake, and that certain care has to be taken when deploying openai/triton. The current flag-inference logic is platform-specific and, of course, it isn't expected to be universal either. But we should probably work out a solution for making it configurable, so that e.g. distributions can set up their environments to meet triton's expectations.

Some concrete examples of issues that arise:
- On NixOS, the `libcuda.so` user-space driver is deployed in a special location, `/run/opengl-driver/lib`, and `whereis` wouldn't produce any reasonable output because `/lib` and `/usr/lib` do not exist. In https://github.com/NixOS/nixpkgs/pull/222273 we end up patching `triton/compiler.py` to pass the correct `-L` flag to the compiler: https://github.com/NixOS/nixpkgs/blob/e4474334415ac41efb5fda33d4cc8f312397ef05/pkgs/development/python-modules/openai-triton/default.nix#L128-L147. We also have to work around triton trying to vendor a copy of `ptxas`.
- In pytorch/pytorch there are a number of confused issues about broken `-lcuda` and `#include <Python.h>`.
An off-the-shelf way of making libpython and cuda flags configurable would be `pkg-config`, although I'd feel weird and conflicted about setting up pkg-config at run-time side by side with pytorch. I also note that this situation is somewhat similar to that of `torch.utils.cpp_extension`, which also attempts to guess build flags at run-time.
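For a sense of what the pkg-config route would look like, here is a minimal sketch; it assumes the distribution ships the relevant `.pc` files (CPython >= 3.8 provides `python3-embed.pc`, while a `cuda` package name would have to come from the packager):

```python
import subprocess

def pkg_config_flags(package):
    # Query pkg-config for the compile and link flags of a package, instead
    # of guessing include and library paths at run-time.
    out = subprocess.check_output(["pkg-config", "--cflags", "--libs", package])
    return out.decode().split()

# python3-embed.pc ships with CPython; "cuda" is an assumed package name.
python_flags = pkg_config_flags("python3-embed")
```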