Closed by adam-hartshorne 1 year ago.
This problem is caused by the OpenXLA Triton fork, whose BUILD file blindly adds GCC/Clang-style copts that MSVC does not understand: https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L38-L40, https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L92-L99, and https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L122-L124. They don't accept PRs, so there is no way for us to fix it.
What is Triton being used for by JAX? I thought it was an alternative to CUDA for writing efficient mathematical primitives, accessible to JAX users via jax-triton. Would it be possible to make it an optional component of the JAX build?
@cloudhan I think that's an oversight; I'm following up with the owners of the OpenXLA Triton fork. (Almost all changes should be sent to upstream Triton, but the Bazel BUILD files aren't upstream, so they have to be patched in the fork.)
@adam-hartshorne XLA is using Triton internally on GPU for code generation of certain fusions. So it's not easy to make it optional and even if we could we'd regress performance for some models.
@cloudhan If you share what fixes are needed, I can apply them for now.
@hawkinsp Even with a select to discard those copts, there will still be compile errors, as follows:
external/triton/lib/codegen/selection/generator.cc(625): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(625): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(627): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(627): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(632): error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2348): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2341): note: or 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(632): note: while trying to match the argument list '(triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(729): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(729): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(731): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(731): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(736): error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2348): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2341): note: or 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(736): note: while trying to match the argument list '(triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(891): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(891): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(989): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(989): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(994): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\llvm/IR/IRBuilder.h(2363): note: or 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(994): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(1416): error C2059: syntax error: '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C2143: syntax error: missing ')' before '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C2661: 'llvm::IRBuilderBase::CreateCall': no overloaded function takes 0 arguments
external/triton/lib/codegen/selection/generator.cc(1432): error C2143: syntax error: missing ';' before '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C4235: nonstandard extension used: '__asm' keyword not supported on this architecture
external/triton/lib/codegen/selection/generator.cc(1432): error C2059: syntax error: ','
external/triton/lib/codegen/selection/generator.cc(1432): error C2059: syntax error: ')'
So it will need some patches to upstream Triton. I might take a look when I am free.
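The C2668 errors above say that, for the int index Triton passes, MSVC considers both the integer-index and the Value *-index overloads of CreateInsertElement/CreateExtractElement viable. A minimal sketch of one way out, assuming an explicit cast is acceptable at the real call sites, is to pass the index as uint64_t so that overload becomes an exact match. Builder, Value, Chosen, and InsertAtLane are hypothetical stand-ins that only copy the overload shapes from IRBuilder.h; this is an illustration, not taken from any actual Triton patch.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-ins: these are NOT llvm::Value / llvm::IRBuilderBase.
// Only the overload shapes mirror the two CreateInsertElement signatures
// quoted from IRBuilder.h in the log.
struct Value {};

enum class Chosen { IntegerIndex, ValueIndex };

struct Builder {
  // Shape of CreateInsertElement(Value*, Value*, uint64_t, const Twine&).
  Chosen CreateInsertElement(Value*, Value*, uint64_t, const std::string& = "") {
    return Chosen::IntegerIndex;
  }
  // Shape of CreateInsertElement(Value*, Value*, Value*, const Twine&).
  Chosen CreateInsertElement(Value*, Value*, Value*, const std::string& = "") {
    return Chosen::ValueIndex;
  }
};

// Inserts at a constant lane. The explicit uint64_t makes the first overload
// an exact match, so resolution no longer depends on how a compiler ranks
// int -> uint64_t against a possible int -> Value* (null pointer) conversion.
Chosen InsertAtLane(Builder& b, Value* vec, Value* elt, int lane) {
  return b.CreateInsertElement(vec, elt, static_cast<uint64_t>(lane));
}
```

Passing an actual Value * index still selects the second overload, so only the constant-index call sites would need the cast.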
If all of those are fixed, you will get another link error, as follows:
LINK : warning LNK4044: unrecognized option '/lm'; ignored
ffi.lib(ffi.obj) : error LNK2005: "struct XLA_FFI_Stream * __cdecl xla::runtime::ffi::GetXlaFfiStream(class xla::runtime::PtrMapByType<class xla::runtime::CustomCall,16> const *,class xla::runtime::DiagnosticEngine const *)" (?GetXlaFfiStream@ffi@runtime@xla@@YAPEAUXLA_FFI_Stream@@PEBV?$PtrMapByType@VCustomCall@runtime@xla@@$0BA@@23@PEBVDiagnosticEngine@23@@Z) already defined in executable.lib(executable.obj)
Creating library bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.lib and object bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.exp
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'gpu_executable.lib(gpu_executable.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'send_recv.lib(send_recv.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'llvm_gpu_backend.lib(gpu_backend_lib.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'transpose.lib(transpose.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'bfc_allocator.lib(bfc_allocator.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'nvptx_compiler_impl.lib(nvptx_compiler.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'gpu_compiler.lib(gpu_compiler.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'cpu_runtime.lib(cpu_runtime.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'gpu_helpers.lib(gpu_helpers.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'pjrt_stream_executor_client.lib(pjrt_stream_executor_client.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'local_device_state.lib(local_device_state.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'profiler.lib(profiler.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'outfeed_receiver.lib(outfeed_receiver.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'py_client.lib(py_values.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'tfrt_cpu_pjrt_client.lib(tfrt_cpu_pjrt_client.obj)'
LINK : warning LNK4217: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'allocator_registry_impl.lo.lib(cpu_allocator_impl.obj)' in function '"public: static void __cdecl tsl::profiler::TraceMe::InstantActivity<class <lambda_29f743e77e718fe99c3f5b22e598e942>,1>(class <lambda_29f743e77e718fe99c3f5b22e598e942> &&,int)" (??$InstantActivity@V<lambda_29f743e77e718fe99c3f5b22e598e942>@@$00@TraceMe@profiler@tsl@@SAX$$QEAV<lambda_29f743e77e718fe99c3f5b22e598e942>@@H@Z)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'pmap_lib.lib(pmap_lib.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'pjit.lib(pjit.obj)'
LINK : warning LNK4286: symbol '?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)' defined in 'traceme_recorder_impl.lo.lib(traceme_recorder.obj)' is imported by 'jax_jit.lib(jax_jit.obj)'
LINK : warning LNK4217: symbol '?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)' defined in 'annotation_stack_impl.lo.lib(annotation_stack.obj)' is imported by 'gpu_executable.lib(gpu_executable.obj)' in function '"public: __cdecl tsl::profiler::ScopedAnnotationT<0>::ScopedAnnotationT<0><class <lambda_aeb2b8c334a04b454d1eb165a0a6ffbd> >(class <lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>)" (??$?0V<lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>@@@?$ScopedAnnotationT@$0A@@profiler@tsl@@QEAA@V<lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>@@@Z)'
LINK : warning LNK4286: symbol '?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)' defined in 'annotation_stack_impl.lo.lib(annotation_stack.obj)' is imported by 'gpu_executable.lib(sequential_thunk.obj)'
LINK : warning LNK4286: symbol '?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)' defined in 'annotation_stack_impl.lo.lib(annotation_stack.obj)' is imported by 'tracing.lib(tracing.obj)'
bazel-out\x64_windows-opt\bin\external\org_tensorflow\tensorflow\compiler\xla\python\xla_extension.so : fatal error LNK1169: one or more multiply defined symbols found
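The fatal LNK1169 comes from the single LNK2005 line: GetXlaFfiStream is defined both in ffi.obj and in executable.obj. The LNK4286/LNK4217 lines are only warnings (a symbol declared dllimport but defined in the same image) and do not fail the link by themselves. The log does not say why the function is defined twice; one common cause of this pattern is a non-inline function defined in a header that two translation units include. The sketch below assumes that cause and shows the usual remedy of marking the header definition inline (keeping a single definition in one .cc file works too); Stream and GetDefaultStream are hypothetical stand-ins, not XLA code.

```cpp
#include <cassert>

// --- hypothetical header (e.g. stream_util.h) included by two .cc files ---
struct Stream {
  int id;
};

// Without "inline", every translation unit that includes this header emits
// its own strong definition, and linking two of them fails with LNK2005 on
// MSVC ("multiple definition" with GNU ld). With "inline", the linker keeps
// one copy.
inline Stream* GetDefaultStream() {
  // A static local in an inline function is a single object shared by all
  // translation units of the same binary.
  static Stream stream{0};
  return &stream;
}
```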
@hawkinsp These are the patches for openxla/triton at commit 2c3853269281da6742cf469a5ca5772947d271ce: 0001-Exclude-copts-for-MSVC.patch and 0002-Fix-compiling-error.patch
MSVC complains with error C2059: syntax error: '__asm' when you try to name a variable _asm, because MSVC treats _asm as a synonym for its __asm keyword.
....
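A minimal sketch of the rename, assuming the offending code keeps an inline-assembly string in a local variable. The function name MakeInlineAsm, the variable name ptx_asm, and the PTX-style string are hypothetical and only illustrate the pattern:

```cpp
#include <cassert>
#include <string>

// MSVC treats _asm as a synonym for its __asm keyword (unless /Za is used),
// so a local variable named _asm is read as that keyword rather than an
// identifier, producing errors like C2059. On x64, MSVC also reports C4235
// because inline __asm is not supported there.
// Renaming the identifier is enough; "ptx_asm" is a hypothetical name.
std::string MakeInlineAsm(const std::string& opcode) {
  std::string ptx_asm = opcode + " $0, $1;";  // previously named _asm
  return ptx_asm;
}
```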
@hawkinsp Can you coordinate merging the previous two patches to fix the Triton build on Windows? openxla/triton does not have an ETA for moving to triton2, and they say they do not accept PRs, so we need some other way to fix it. After that, I think I can fix #14466 on my side to re-enable the Windows build.
There is now a new issue with the use of Triton in jaxlib 0.4.6 (see the attachment for the full error output):
external/triton/lib/Target/LLVMIR/LLVMIRTranslation.cpp(24): fatal error C1083: Cannot open include file: 'dlfcn.h': No such file or directory
Windows doesn't have the dlopen API, and thus there is no dlfcn.h header.
Attachment: jax_0_4_6_build_error.txt
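The log only shows the failing include, not which dlfcn.h functions LLVMIRTranslation.cpp uses. One common way to keep such code building on Windows is a small #ifdef shim: dlopen/dlsym/dlclose on POSIX, LoadLibraryA/GetProcAddress/FreeLibrary on Windows. The sketch below shows that general technique; OpenSharedLibrary, LookupSymbol, CloseSharedLibrary, and LibHandle are hypothetical names, and this is not the fix that was applied upstream.

```cpp
#include <cassert>
#include <string>

#ifdef _WIN32
#include <windows.h>
using LibHandle = HMODULE;
#else
#include <dlfcn.h>  // POSIX only; this header does not exist on Windows
using LibHandle = void*;
#endif

// Opens a shared library by path; returns nullptr on failure.
LibHandle OpenSharedLibrary(const std::string& path) {
#ifdef _WIN32
  return LoadLibraryA(path.c_str());
#else
  return dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
#endif
}

// Looks up a symbol in an opened library; returns nullptr if it is missing.
void* LookupSymbol(LibHandle lib, const char* name) {
#ifdef _WIN32
  return reinterpret_cast<void*>(GetProcAddress(lib, name));
#else
  return dlsym(lib, name);
#endif
}

// Releases a handle returned by OpenSharedLibrary.
void CloseSharedLibrary(LibHandle lib) {
#ifdef _WIN32
  FreeLibrary(lib);
#else
  dlclose(lib);
#endif
}
```

On older glibc (before 2.34) the POSIX branch also needs -ldl at link time.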
This particular issue is fixed, but we need the following patch to openxla/triton for Triton (inside OpenXLA) to build on Windows:
--- a/triton/BUILD
+++ b/triton/BUILD
@@ -58,6 +58,11 @@ config_setting(
     "//conditions:default": ["-Wno-unused-variable -Wno-parentheses"],
 })
 
+_no_parentheses = select({
+    ":compiler_is_msvc": [],
+    "//conditions:default": ["-Wno-parentheses"],
+})
+
 td_library(
     name = "td_files",
     srcs = glob(["include/triton/**/*.td"]),
@@ -356,7 +361,7 @@ cc_library(
     name = "TritonTransforms",
     srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]),
     hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]),
-    copts = ["-Wno-parentheses"],
+    copts = _no_parentheses,
     includes = ["include"],
     deps = [
         ":TritonDialects",
That's it, though.
https://github.com/openxla/xla/commit/972cd211a24458be7a867678059a4a4652955f9f fixed this (at XLA head).
Description
Using the following build command
python .\build\build.py --enable_cuda --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cuda_compute_capabilities="7.5" --cuda_version="11.7" --cudnn_version="8.4.0" --noenable_rocm --noenable_tpu
the build fails with a series of invalid numeric argument '/Wno-error' errors as follows.
What jax/jaxlib version are you using?
jaxlib v0.4.3, jax 0.4.3
Which accelerator(s) are you using?
GPU
Additional system info
Windows 10, Python 3.9, Cuda 11.7, Cudnn 8.4.0
NVIDIA GPU info
No response