triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Regression on certain pointwise code between dev20221202 and 2.0.0a2 #1266

Open cpuhrsch opened 1 year ago

cpuhrsch commented 1 year ago

Hey all, I have the following generated kernel https://gist.github.com/cpuhrsch/df44958f644c0b979fecc999ed3187ee and see a large regression between dev20221202 and 2.0.0a2. In particular, triton0 and triton2 seem to slow down a lot, with bandwidth dropping by about 3x to 4x. This is on an A100 with a PyTorch nightly. The only change to my environment is installing different Triton versions (see comments on the gist).
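
For readers comparing numbers across versions, here is a minimal sketch of how effective bandwidth for a pointwise kernel is commonly derived from bytes moved and measured time. The function names, tensor size, and timings are illustrative assumptions; none of them come from the gist or from measurements in this issue.

```python
def effective_bandwidth_gbps(bytes_moved: int, elapsed_ms: float) -> float:
    """Return bytes read plus bytes written, divided by kernel time, in GB/s."""
    if elapsed_ms <= 0:
        raise ValueError("elapsed_ms must be positive")
    return bytes_moved / (elapsed_ms * 1e-3) / 1e9


def slowdown(old_ms: float, new_ms: float) -> float:
    """Return how many times slower the new timing is than the old one."""
    if old_ms <= 0:
        raise ValueError("old_ms must be positive")
    return new_ms / old_ms


if __name__ == "__main__":
    # Hypothetical kernel: reads one float32 tensor and writes one of equal size.
    numel = 64 * 1024 * 1024
    bytes_moved = 2 * numel * 4  # one read + one write, 4 bytes per element

    # Placeholder timings, NOT measurements from this issue.
    old_ms, new_ms = 0.5, 1.75

    print(f"old: {effective_bandwidth_gbps(bytes_moved, old_ms):.1f} GB/s")
    print(f"new: {effective_bandwidth_gbps(bytes_moved, new_ms):.1f} GB/s")
    print(f"slowdown: {slowdown(old_ms, new_ms):.2f}x")
```

Because the bytes moved are the same for both versions, the bandwidth ratio equals the inverse of the time ratio, so either figure can be reported.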

I was encouraged to post this here after discussion on slack. cc @Jokeren

In the meantime I'll try different versions of Triton and various commits.

EDIT: Running it on Triton installed using commit be6217cce seems to cause a segfault.

Jokeren commented 1 year ago

> Running it on Triton installed using commit be6217cce seems to cause a segfault.

That's not expected. Have you run pip uninstall pytorch-triton -y?

If it still segfaults, can you copy and paste the stack trace?
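
A quick way to check for leftover installs before rebuilding: the sketch below lists which Triton distributions are present in the current environment. It assumes the distributions are named triton, pytorch-triton, and triton-nightly and that each one can supply the same importable triton package; the helper name is hypothetical.

```python
from importlib import metadata

# Distribution names assumed to provide the importable `triton` package.
CANDIDATES = ("triton", "pytorch-triton", "triton-nightly")


def installed_triton_distributions(names=CANDIDATES):
    """Map each installed distribution name to its version string."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in this environment; skip it.
            pass
    return found


if __name__ == "__main__":
    found = installed_triton_distributions()
    for name, version in found.items():
        print(f"{name}=={version}")
    if len(found) > 1:
        print("warning: more than one Triton distribution is installed")
```

If more than one entry is printed, uninstalling all of them and reinstalling only the build under test removes one source of ambiguity.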

cpuhrsch commented 1 year ago

@Jokeren, sure. I installed Triton with CXX=clang++-10 CC=clang-10 TRITON_USE_ASSERT_ENABLED_LLVM=False pip install -v -e . from within the python folder.

Then I'm running the code in the gist and get

$ python ~/tmp/a_dev200a2.py
realloc(): invalid pointer
zsh: abort (core dumped)  python ~/tmp/a_dev200a2.py

The stack trace doesn't look very relevant to Triton.

Thread 1 "python" hit Catchpoint 1 (exception thrown), __cxxabiv1::__cxa_throw (obj=0x634b420, tinfo=0x7fffc16abff0 <typeinfo for c10::Error>, dest=0x7fff7356dc70 <c10::Error::~Error()>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc:77
77      /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc: No such file or directory.
#0  __cxxabiv1::__cxa_throw (obj=0x634b420, tinfo=0x7fffc16abff0 <typeinfo for c10::Error>, dest=0x7fff7356dc70 <c10::Error::~Error()>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc:77
#1  0x00007fff7356a90f in c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) () from /scratch/cpuhrsch/miniconda3/envs/nightly20230301py310/lib/python3.10/site-packages/torch/lib/libc10.so
#2  0x00007fffc0f3a035 in torch::jit::initJITBindings(_object*)::{lambda(std::string const&)#193}::operator()(std::string const&) const [clone .isra.0] ()
   from /scratch/cpuhrsch/miniconda3/envs/nightly20230301py310/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#3  0x00007fffc0f3a24c in void pybind11::cpp_function::initialize<torch::jit::initJITBindings(_object*)::{lambda(std::string const&)#193}, pybind11::tuple, std::string const&, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg>(torch::jit::initJITBindings(_object*)::{lambda(std::string const&)#193}&&, pybind11::tuple (*)(std::string const&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) () from /scratch/cpuhrsch/miniconda3/envs/nightly20230301py310/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#4  0x00007fffc0b53d55 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /scratch/cpuhrsch/miniconda3/envs/nightly20230301py310/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#5  0x00000000004fee27 in cfunction_call () at /usr/local/src/conda/python-3.10.9/Objects/pycore_bitutils.h:543

I'll recompile PyTorch from source with debug symbols and try to get you a better trace.

cpuhrsch commented 1 year ago

Nope, it's still as opaque. Trying with CUDA_LAUNCH_BLOCKING=1 next. EDIT: Doesn't bring more clarity either.

Thread 1 "python" hit Catchpoint 1 (exception thrown), __cxxabiv1::__cxa_throw (obj=0x61ecec0, tinfo=0x7fffc68c18c0 <typeinfo for c10::Error>, dest=0x7fff996bed08 <c10::Error::~Error()>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc:77
77      /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc: No such file or directory.
#0  __cxxabiv1::__cxa_throw (obj=0x61ecec0, tinfo=0x7fffc68c18c0 <typeinfo for c10::Error>, dest=0x7fff996bed08 <c10::Error::~Error()>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/build/gcc-final/x86_64-conda-linux-gnu/libstdc++-v3/libsupc++/eh_throw.cc:77
#1  0x00007fff99710ec0 in c10::detail::torchCheckFail (
    func=0x7fffc5ef9778 <torch::jit::initJITBindings(_object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#194}::operator()(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) const::__func__> "operator()", file=0x7fffc5ef9448 "/scratch/cpuhrsch/pytorch/torch/csrc/jit/python/init.cpp", line=1575, msg=...) at /scratch/cpuhrsch/pytorch/c10/util/Exception.cpp:87
#2  0x00007fffc5652152 in torch::jit::<lambda(const string&)>::operator()(const std::__cxx11::string &) const (__closure=0x3fc1ec8, op_name=...) at /scratch/cpuhrsch/pytorch/torch/csrc/jit/python/init.cpp:1575
#3  0x00007fffc56b4ebd in pybind11::detail::argument_loader<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&>::call_impl<pybind11::tuple, torch::jit::initJITBindings(PyObject*)::<lambda(const string&)>&, 0, pybind11::detail::void_type>(torch::jit::<lambda(const string&)> &, std::index_sequence, pybind11::detail::void_type &&) (this=0x7fffffffa980, f=...) at /scratch/cpuhrsch/pytorch/third_party/pybind11/include/pybind11/cast.h:1439
#4  0x00007fffc56ab4f1 in pybind11::detail::argument_loader<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&>::call<pybind11::tuple, pybind11::detail::void_type, torch::jit::initJITBindings(PyObject*)::<lambda(const string&)>&>(torch::jit::<lambda(const string&)> &) (this=0x7fffffffa980, f=...) at /scratch/cpuhrsch/pytorch/third_party/pybind11/include/pybind11/cast.h:1408
#5  0x00007fffc569b25c in pybind11::cpp_function::<lambda(pybind11::detail::function_call&)>::operator()(pybind11::detail::function_call &) const (__closure=0x0, call=...)
    at /scratch/cpuhrsch/pytorch/third_party/pybind11/include/pybind11/pybind11.h:249
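
When the native trace is this opaque, it can help to see which Python frame was active at the abort. The sketch below reruns a script in a child process with CUDA_LAUNCH_BLOCKING=1 and Python's built-in faulthandler enabled, which dumps the Python-level traceback on signals such as SIGSEGV and SIGABRT. This is a suggestion rather than something tried in this thread; the helper name is hypothetical, and the script path in the usage note is simply the one from the earlier comment.

```python
import os
import subprocess
import sys


def run_with_debug_env(script_args, extra_env=None):
    """Run the current Python interpreter on script_args with debugging aids enabled."""
    env = dict(os.environ)
    env["CUDA_LAUNCH_BLOCKING"] = "1"  # make CUDA kernel launches synchronous
    env.update(extra_env or {})
    # -X faulthandler dumps the Python traceback on fatal signals.
    cmd = [sys.executable, "-X", "faulthandler", *script_args]
    return subprocess.run(cmd, env=env, capture_output=True, text=True)


if __name__ == "__main__":
    # Self-check: confirm the child process sees both settings.
    probe = (
        "import faulthandler, os; "
        "print(faulthandler.is_enabled(), os.environ['CUDA_LAUNCH_BLOCKING'])"
    )
    result = run_with_debug_env(["-c", probe])
    print(result.stdout.strip())
```

To apply it to the failing script, call run_with_debug_env with the script path (for example the expanded form of ~/tmp/a_dev200a2.py) and inspect result.stderr. The dump only covers Python frames, so it complements the gdb backtrace rather than replacing it.
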

Jokeren commented 1 year ago

I cannot reproduce it on my end.

Can you try https://github.com/openai/triton/commit/f8c92c3d17e39239a087d951608d9c97dfa62ab9?

If that still doesn't solve the problem, there may be an underlying conflict between library handles.

cpuhrsch commented 1 year ago

@Jokeren - I still run into the same issue. Are you able to reproduce the slowness on a recent main commit? I'm happy to continue using dev20221202 for my development for now and wait for a new pypi release if you can't reproduce the slowness.

Jokeren commented 1 year ago

Yes, I can reproduce the slowness issue.

bingo787 commented 8 months ago

I had the same problem! performance: triton-nightly < triton2.x.x < triton2.0.0.dev20221202

minjang commented 7 months ago

> I had the same problem! performance: triton-nightly < triton2.x.x < triton2.0.0.dev20221202

@bingo787 Can you share your recent example that can reproduce this performance regression?

The reported code in the gist is a year old, and I was unable to reproduce it locally due to some PyTorch compile issues.

bingo787 commented 7 months ago

> > I had the same problem! performance: triton-nightly < triton2.x.x < triton2.0.0.dev20221202
>
> @bingo787 Can you share your recent example that can reproduce this performance regression?
>
> The reported code in the gist is a year old, and I was unable to reproduce it locally due to some PyTorch compile issues.

Yes, here is test code that reproduces this performance regression: https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

minjang commented 7 months ago

@bingo787 Thanks for the pointer! When you reproduced the perf regression with your code, were you still comparing the year-old versions originally reported in this issue? We use pretty much the latest Triton, 3.0.x, so I was wondering which versions you compared.
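
One way to make such comparisons unambiguous is to attach the installed versions to every timing. The sketch below is a generic harness, not code from this thread: the function names are hypothetical, and it times with a wall clock, so for a GPU kernel the callable is assumed to synchronize the device itself (for example by calling torch.cuda.synchronize() before returning).

```python
import json
import platform
import statistics
import time
from importlib import metadata


def package_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def benchmark_record(fn, repeats=20, warmup=3):
    """Time fn and return the median alongside the versions it ran under."""
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-time costs such as JIT compilation
    samples_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  # assumed to block until the work has finished
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "median_ms": statistics.median(samples_ms),
        "repeats": repeats,
        "python": platform.python_version(),
        "triton": package_version("triton"),
        "torch": package_version("torch"),
    }


if __name__ == "__main__":
    # Stand-in CPU workload; replace with a call that launches the kernel.
    record = benchmark_record(lambda: sum(range(10_000)))
    print(json.dumps(record, indent=2))
```

Running the same script once per environment and posting the resulting records side by side shows exactly which versions each number belongs to.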