jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Build Fail with invalid numeric argument '/Wno-error' #14369

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

Description

Using the following build command

python .\build\build.py --enable_cuda --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cuda_compute_capabilities="7.5" --cuda_version="11.7" --cudnn_version="8.4.0" --noenable_rocm --noenable_tpu

the build fails with a series of invalid numeric argument '/Wno-error' errors as follows.

C:/users/adam/_bazel_adam/ewkz5nyk/external/triton/BUILD:46:11: Compiling lib/codegen/analysis/align.cc failed: (Exit 2): python.exe failed: error executing command
  cd /d C:/users/adam/_bazel_adam/ewkz5nyk/execroot/__main__
  SET CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET INCLUDE=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\ATLMFC\include;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\include;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\ucrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\shared;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\winrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\cppwinrt
    SET LIB=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\ATLMFC\lib\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\lib\x64;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\lib\um\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\ucrt\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\um\x64
    SET PATH=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\\Extensions\Microsoft\IntelliCode\CLI;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\bin\HostX64\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\VC\VCPackages;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TestWindow;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TeamFoundation\Team Explorer;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\bin\Roslyn;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\\x64;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\;C:\Program Files (x86)\Microsoft SDKs\Windows\v10.0A\bin\NETFX 4.8 Tools\x64\;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\devinit;C:\Program Files (x86)\Windows Kits\10\bin\10.0.18362.0\x64;C:\Program Files (x86)\Windows Kits\10\bin\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\\MSBuild\Current\Bin;C:\Windows\Microsoft.NET\Framework64\v4.0.30319;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\;;C:\WINDOWS\system32;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja
    SET PWD=/proc/self/cwd
    SET RUNFILES_MANIFEST_ONLY=1
    SET TEMP=C:\Users\Adam\AppData\Local\Temp
    SET TF_CUDA_COMPUTE_CAPABILITIES=7.5
    SET TF_CUDA_PATHS=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET TF_CUDA_VERSION=11.7
    SET TF_CUDNN_VERSION=8.4.0
    SET TMP=C:\Users\Adam\AppData\Local\Temp
  C:\Users\Adam\anaconda3\envs\jax_latest\python.exe -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py /nologo /DCOMPILER_MSVC /DNOMINMAX /D_WIN32_WINNT=0x0600 /D_CRT_SECURE_NO_DEPRECATE /D_CRT_SECURE_NO_WARNINGS /D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS /bigobj /Zm500 /J /Gy /GF /EHsc /wd4351 /wd4291 /wd4250 /wd4996 /Iexternal/triton /Ibazel-out/x64_windows-opt/bin/external/triton /Iexternal/llvm-project /Ibazel-out/x64_windows-opt/bin/external/llvm-project /Iexternal/llvm_terminfo /Ibazel-out/x64_windows-opt/bin/external/llvm_terminfo /Iexternal/llvm_zlib /Ibazel-out/x64_windows-opt/bin/external/llvm_zlib /Ibazel-out/x64_windows-opt/bin/external/llvm-project/llvm/_virtual_includes/InstCombineTableGen /Iexternal/triton/include /Ibazel-out/x64_windows-opt/bin/external/triton/include /Iexternal/llvm-project/llvm/include /Ibazel-out/x64_windows-opt/bin/external/llvm-project/llvm/include /D_CRT_SECURE_NO_DEPRECATE /D_CRT_SECURE_NO_WARNINGS /D_CRT_NONSTDC_NO_DEPRECATE /D_CRT_NONSTDC_NO_WARNINGS /D_SCL_SECURE_NO_DEPRECATE /D_SCL_SECURE_NO_WARNINGS /DUNICODE /D_UNICODE /DLTDL_SHLIB_EXT=".dll" /DLLVM_PLUGIN_EXT=".dll" /DLLVM_NATIVE_ARCH="X86" /DLLVM_NATIVE_ASMPARSER=LLVMInitializeX86AsmParser /DLLVM_NATIVE_ASMPRINTER=LLVMInitializeX86AsmPrinter /DLLVM_NATIVE_DISASSEMBLER=LLVMInitializeX86Disassembler /DLLVM_NATIVE_TARGET=LLVMInitializeX86Target /DLLVM_NATIVE_TARGETINFO=LLVMInitializeX86TargetInfo /DLLVM_NATIVE_TARGETMC=LLVMInitializeX86TargetMC /DLLVM_NATIVE_TARGETMCA=LLVMInitializeX86TargetMCA /DLLVM_HOST_TRIPLE="x86_64-pc-win32" /DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-win32" /DLLVM_VERSION_MAJOR=17 /DLLVM_VERSION_MINOR=0 /DLLVM_VERSION_PATCH=0 /DLLVM_VERSION_STRING="17.0.0git" /D__STDC_LIMIT_MACROS /D__STDC_CONSTANT_MACROS /D__STDC_FORMAT_MACROS /DBLAKE3_USE_NEON=0 /DBLAKE3_NO_AVX2 /DBLAKE3_NO_AVX512 /DBLAKE3_NO_SSE2 /DBLAKE3_NO_SSE41 /showIncludes /MD /O2 /DNDEBUG /D_USE_MATH_DEFINES -DWIN32_LEAN_AND_MEAN -DNOGDI /Zc:preprocessor 
-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. /std:c++17 -fexceptions -Wno-error -Wno-non-virtual-dtor -Wno-reorder-ctor -Wno-unused-variable -Wno-braced-scalar-init /Fobazel-out/x64_windows-opt/bin/external/triton/_objs/codegen/align.obj /c external/triton/lib/codegen/analysis/align.cc
# Configuration: 8cec30d85ae15acaaae5944b6bcc30efc40becd3d6a22599d268822a3b6f8357
# Execution platform: @local_execution_config_platform//:platform
cl : Command line error D8021 : invalid numeric argument '/Wno-error'
ERROR: C:/users/adam/_bazel_adam/ewkz5nyk/external/triton/BUILD:46:11: Compiling lib/codegen/transform/reorder.cc failed: (Exit 2): python.exe failed: error executing command
  cd /d C:/users/adam/_bazel_adam/ewkz5nyk/execroot/__main__
  SET CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET INCLUDE=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\ATLMFC\include;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\include;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\ucrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\shared;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\winrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\cppwinrt
    SET LIB=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\ATLMFC\lib\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\lib\x64;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\lib\um\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\ucrt\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\um\x64
    SET PATH=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\\Extensions\Microsoft\IntelliCode\CLI;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\bin\HostX64\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\VC\VCPackages;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TestWindow;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TeamFoundation\Team Explorer;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\bin\Roslyn;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\\x64;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\;C:\Program Files (x86)\Microsoft SDKs\Windows\v10.0A\bin\NETFX 4.8 Tools\x64\;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\devinit;C:\Program Files (x86)\Windows Kits\10\bin\10.0.18362.0\x64;C:\Program Files (x86)\Windows Kits\10\bin\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\\MSBuild\Current\Bin;C:\Windows\Microsoft.NET\Framework64\v4.0.30319;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\;;C:\WINDOWS\system32;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja
    SET PWD=/proc/self/cwd
    SET RUNFILES_MANIFEST_ONLY=1
    SET TEMP=C:\Users\Adam\AppData\Local\Temp
    SET TF_CUDA_COMPUTE_CAPABILITIES=7.5
    SET TF_CUDA_PATHS=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7
    SET TF_CUDA_VERSION=11.7
    SET TF_CUDNN_VERSION=8.4.0
    SET TMP=C:\Users\Adam\AppData\Local\Temp
  C:\Users\Adam\anaconda3\envs\jax_latest\python.exe -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py /nologo /DCOMPILER_MSVC /DNOMINMAX /D_WIN32_WINNT=0x0600 /D_CRT_SECURE_NO_DEPRECATE /D_CRT_SECURE_NO_WARNINGS /D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS /bigobj /Zm500 /J /Gy /GF /EHsc /wd4351 /wd4291 /wd4250 /wd4996 /Iexternal/triton /Ibazel-out/x64_windows-opt/bin/external/triton /Iexternal/llvm-project /Ibazel-out/x64_windows-opt/bin/external/llvm-project /Iexternal/llvm_terminfo /Ibazel-out/x64_windows-opt/bin/external/llvm_terminfo /Iexternal/llvm_zlib /Ibazel-out/x64_windows-opt/bin/external/llvm_zlib /Ibazel-out/x64_windows-opt/bin/external/llvm-project/llvm/_virtual_includes/InstCombineTableGen /Iexternal/triton/include /Ibazel-out/x64_windows-opt/bin/external/triton/include /Iexternal/llvm-project/llvm/include /Ibazel-out/x64_windows-opt/bin/external/llvm-project/llvm/include /D_CRT_SECURE_NO_DEPRECATE /D_CRT_SECURE_NO_WARNINGS /D_CRT_NONSTDC_NO_DEPRECATE /D_CRT_NONSTDC_NO_WARNINGS /D_SCL_SECURE_NO_DEPRECATE /D_SCL_SECURE_NO_WARNINGS /DUNICODE /D_UNICODE /DLTDL_SHLIB_EXT=".dll" /DLLVM_PLUGIN_EXT=".dll" /DLLVM_NATIVE_ARCH="X86" /DLLVM_NATIVE_ASMPARSER=LLVMInitializeX86AsmParser /DLLVM_NATIVE_ASMPRINTER=LLVMInitializeX86AsmPrinter /DLLVM_NATIVE_DISASSEMBLER=LLVMInitializeX86Disassembler /DLLVM_NATIVE_TARGET=LLVMInitializeX86Target /DLLVM_NATIVE_TARGETINFO=LLVMInitializeX86TargetInfo /DLLVM_NATIVE_TARGETMC=LLVMInitializeX86TargetMC /DLLVM_NATIVE_TARGETMCA=LLVMInitializeX86TargetMCA /DLLVM_HOST_TRIPLE="x86_64-pc-win32" /DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-win32" /DLLVM_VERSION_MAJOR=17 /DLLVM_VERSION_MINOR=0 /DLLVM_VERSION_PATCH=0 /DLLVM_VERSION_STRING="17.0.0git" /D__STDC_LIMIT_MACROS /D__STDC_CONSTANT_MACROS /D__STDC_FORMAT_MACROS /DBLAKE3_USE_NEON=0 /DBLAKE3_NO_AVX2 /DBLAKE3_NO_AVX512 /DBLAKE3_NO_SSE2 /DBLAKE3_NO_SSE41 /showIncludes /MD /O2 /DNDEBUG /D_USE_MATH_DEFINES -DWIN32_LEAN_AND_MEAN -DNOGDI /Zc:preprocessor 
-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. /std:c++17 -fexceptions -Wno-error -Wno-non-virtual-dtor -Wno-reorder-ctor -Wno-unused-variable -Wno-braced-scalar-init /Fobazel-out/x64_windows-opt/bin/external/triton/_objs/codegen/reorder.obj /c external/triton/lib/codegen/transform/reorder.cc
# Configuration: 8cec30d85ae15acaaae5944b6bcc30efc40becd3d6a22599d268822a3b6f8357
# Execution platform: @local_execution_config_platform//:platform
cl : Command line error D8021 : invalid numeric argument '/Wno-error'
Target //build:build_wheel failed to build
INFO: Elapsed time: 5269.619s, Critical Path: 469.96s
INFO: 9055 processes: 3120 internal, 5935 local.
FAILED: Build did NOT complete successfully
FAILED: Build did NOT complete successfully
b''
Traceback (most recent call last):
  File "C:\sdks\jax-jaxlib-v0.4.3\build\build.py", line 567, in <module>
    main()
  File "C:\sdks\jax-jaxlib-v0.4.3\build\build.py", line 562, in main
    shell(command)
  File "C:\sdks\jax-jaxlib-v0.4.3\build\build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "C:\Users\Adam\anaconda3\envs\jax_latest\lib\subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "C:\Users\Adam\anaconda3\envs\jax_latest\lib\subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['.\\bazel-5.1.1-windows-x86_64.exe', 'run', '--verbose_failures=true', ':build_wheel', '--', '--output_path=C:\\sdks\\jax-jaxlib-v0.4.3\\dist', '--cpu=AMD64']' returned non-zero exit status 1.

What jax/jaxlib version are you using?

jaxlib v0.4.3, jax 0.4.3

Which accelerator(s) are you using?

GPU

Additional system info

Windows 10, Python 3.9, Cuda 11.7, Cudnn 8.4.0

NVIDIA GPU info

No response

cloudhan commented 1 year ago

This problem is caused by the OpenXLA Triton fork: its BUILD file blindly adds GCC/Clang-style copts at https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L38-L40, https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L92-L99 and https://github.com/openxla/triton/blob/c3f7b6e297eac767bf07295410cf959b01fe954f/BUILD#L122-L124. They don't accept PRs, so there is no way to fix it.
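For readers following along, here is a minimal Python sketch of what "discarding those copts" means. The error text in the log suggests that cl.exe read `-Wno-error` as its own `/W` option with a non-numeric value, which is why it reports D8021. The helper name `split_copts_for_msvc` and the prefix list are hypothetical; they are not part of Bazel, the OpenXLA wrapper script, or the Triton BUILD file, and the prefix rule is an assumption that only covers the flags seen in this log.

```python
# Hypothetical helper (not part of Bazel or OpenXLA): separate GCC/Clang-only
# flags from the options that can be handed to cl.exe unchanged.
# Assumption: only "-Wno-*" and "-f*" flags need dropping, as in the log above.
GCC_ONLY_PREFIXES = ("-Wno-", "-f")


def split_copts_for_msvc(copts):
    """Return (kept, dropped) lists for a sequence of compiler options."""
    kept, dropped = [], []
    for flag in copts:
        if flag.startswith(GCC_ONLY_PREFIXES):
            dropped.append(flag)
        else:
            kept.append(flag)
    return kept, dropped


# Flags copied from the tail of the failing command line in the issue description.
copts = [
    "/std:c++17",
    "-fexceptions",
    "-Wno-error",
    "-Wno-non-virtual-dtor",
    "-Wno-reorder-ctor",
    "-Wno-unused-variable",
    "-Wno-braced-scalar-init",
    "-DNOGDI",
]

kept, dropped = split_copts_for_msvc(copts)
print("kept:   ", kept)
print("dropped:", dropped)
```

In a BUILD file the same decision would normally be made declaratively, per compiler, rather than by filtering strings at build time.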

adam-hartshorne commented 1 year ago

What is Triton being used for by JAX? I thought it was an alternative to CUDA for writing efficient mathematical primitives, accessible to JAX users via jax-triton. Would it be possible to make it an optional component of the JAX build?

hawkinsp commented 1 year ago

@cloudhan I think that's an oversight; I'm following up with the owners of the OpenXLA Triton fork. (Almost all changes should be sent to upstream Triton, but the Bazel BUILD files aren't upstream, so they have to be patched in the fork.)

@adam-hartshorne XLA is using Triton internally on GPU for code generation of certain fusions. So it's not easy to make it optional, and even if we could, we'd regress performance for some models.

hawkinsp commented 1 year ago

@cloudhan If you share what fixes are needed, I can apply them for now.

cloudhan commented 1 year ago

@hawkinsp Even with a select to discard those copts, there will still be compile errors, as follows:

external/triton/lib/codegen/selection/generator.cc(625): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(625): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(627): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(627): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(632): error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2348): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2341): note: or       'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(632): note: while trying to match the argument list '(triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(729): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(729): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(731): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(731): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(736): error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2348): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2341): note: or       'llvm::Value *llvm::IRBuilderBase::CreateExtractElement(llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(736): note: while trying to match the argument list '(triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(891): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(891): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(989): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(989): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(994): error C2668: 'llvm::IRBuilderBase::CreateInsertElement': ambiguous call to overloaded function
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2370): note: could be 'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,uint64_t,const llvm::Twine &)'
external/llvm-project/llvm/include\\llvm/IR/IRBuilder.h(2363): note: or       'llvm::Value *llvm::IRBuilderBase::CreateInsertElement(llvm::Value *,llvm::Value *,llvm::Value *,const llvm::Twine &)'
external/triton/lib/codegen/selection/generator.cc(994): note: while trying to match the argument list '(triton::codegen::Value *, triton::codegen::Value *, int)'
external/triton/lib/codegen/selection/generator.cc(1416): error C2059: syntax error: '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C2143: syntax error: missing ')' before '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C2661: 'llvm::IRBuilderBase::CreateCall': no overloaded function takes 0 arguments
external/triton/lib/codegen/selection/generator.cc(1432): error C2143: syntax error: missing ';' before '__asm'
external/triton/lib/codegen/selection/generator.cc(1432): error C4235: nonstandard extension used: '__asm' keyword not supported on this architecture
external/triton/lib/codegen/selection/generator.cc(1432): error C2059: syntax error: ','
external/triton/lib/codegen/selection/generator.cc(1432): error C2059: syntax error: ')'

So it will need some patches to upstream Triton. I might take a look when I am free.

cloudhan commented 1 year ago

Once all of those are fixed, you will get another link error, as follows:

LINK : warning LNK4044: unrecognized option \'/lm\'; ignored
ffi.lib(ffi.obj) : error LNK2005: "struct XLA_FFI_Stream * __cdecl xla::runtime::ffi::GetXlaFfiStream(class xla::runtime::PtrMapByType<class xla::runtime::CustomCall,16> const *,class xla::runtime::DiagnosticEngine const *)" (?GetXlaFfiStream@ffi@runtime@xla@@YAPEAUXLA_FFI_Stream@@PEBV?$PtrMapByType@VCustomCall@runtime@xla@@$0BA@@23@PEBVDiagnosticEngine@23@@Z) already defined in executable.lib(executable.obj)
   Creating library bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.lib and object bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.exp
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'send_recv.lib(send_recv.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'llvm_gpu_backend.lib(gpu_backend_lib.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'transpose.lib(transpose.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'bfc_allocator.lib(bfc_allocator.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'nvptx_compiler_impl.lib(nvptx_compiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_compiler.lib(gpu_compiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'cpu_runtime.lib(cpu_runtime.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_helpers.lib(gpu_helpers.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjrt_stream_executor_client.lib(pjrt_stream_executor_client.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'local_device_state.lib(local_device_state.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'profiler.lib(profiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'outfeed_receiver.lib(outfeed_receiver.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'py_client.lib(py_values.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'tfrt_cpu_pjrt_client.lib(tfrt_cpu_pjrt_client.obj)\'
LINK : warning LNK4217: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'allocator_registry_impl.lo.lib(cpu_allocator_impl.obj)\' in function \'"public: static void __cdecl tsl::profiler::TraceMe::InstantActivity<class <lambda_29f743e77e718fe99c3f5b22e598e942>,1>(class <lambda_29f743e77e718fe99c3f5b22e598e942> &&,int)" (??$InstantActivity@V<lambda_29f743e77e718fe99c3f5b22e598e942>@@$00@TraceMe@profiler@tsl@@SAX$$QEAV<lambda_29f743e77e718fe99c3f5b22e598e942>@@H@Z)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pmap_lib.lib(pmap_lib.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjit.lib(pjit.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'jax_jit.lib(jax_jit.obj)\'
LINK : warning LNK4217: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\' in function \'"public: __cdecl tsl::profiler::ScopedAnnotationT<0>::ScopedAnnotationT<0><class <lambda_aeb2b8c334a04b454d1eb165a0a6ffbd> >(class <lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>)" (??$?0V<lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>@@@?$ScopedAnnotationT@$0A@@profiler@tsl@@QEAA@V<lambda_aeb2b8c334a04b454d1eb165a0a6ffbd>@@@Z)\'
LINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(sequential_thunk.obj)\'
LINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'tracing.lib(tracing.obj)\'
bazel-out\\x64_windows-opt\\bin\\external\\org_tensorflow\\tensorflow\\compiler\\xla\\python\\xla_extension.so : fatal error LNK1169: one or more multiply defined symbols found

cloudhan commented 1 year ago

@hawkinsp These are the patches for openxla/triton at 2c3853269281da6742cf469a5ca5772947d271ce: 0001-Exclude-copts-for-MSVC.patch and 0002-Fix-compiling-error.patch

MSVC complains with error C2059: syntax error: '__asm' when you try to name a variable _asm, because MSVC treats _asm as a synonym for the __asm keyword.

cloudhan commented 1 year ago

@hawkinsp Can you coordinate a merge of the previous two patches to fix the Triton build on Windows? openxla/triton does not have an ETA for moving to triton2, and they say they do not accept PRs, so we need some way to fix it. After that, I think I can fix #14466 on my side to re-enable the Windows build.

adam-hartshorne commented 1 year ago

There is now a new issue with the use of Triton in jaxlib 0.4.6 (see the attachment for the full error output):

external/triton/lib/Target/LLVMIR/LLVMIRTranslation.cpp(24): fatal error C1083: Cannot open include file: 'dlfcn.h': No such file or directory

Windows doesn't have the dlopen API, and thus there is no dlfcn.h header.

Attachment: jax_0_4_6_build_error.txt

hawkinsp commented 1 year ago

This particular issue is fixed, but we need the following patch to openxla/triton for Triton (inside OpenXLA) to build on Windows:

--- a/triton/BUILD
+++ b/triton/BUILD
@@ -58,6 +58,11 @@ config_setting(
     "//conditions:default": ["-Wno-unused-variable -Wno-parentheses"],
 })

+_no_parentheses = select({
+    ":compiler_is_msvc": [],
+    "//conditions:default": ["-Wno-parentheses"],
+})
+
 td_library(
     name = "td_files",
     srcs = glob(["include/triton/**/*.td"]),
@@ -356,7 +361,7 @@ cc_library(
     name = "TritonTransforms",
     srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]),
     hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]),
-    copts = ["-Wno-parentheses"],
+    copts = _no_parentheses,
     includes = ["include"],
     deps = [
         ":TritonDialects",

That's it, though.
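To illustrate how the `select()` in the patch above chooses between the two branches, here is a toy Python model of the lookup. It is a simplification for illustration only, not Bazel's implementation: `resolve_select` and the `active` set are hypothetical names, and Bazel's actual rules for multiple matching conditions are more involved than the error raised here.

```python
# Toy model of a Bazel select(); a simplification, not Bazel's real logic.
DEFAULT = "//conditions:default"


def resolve_select(branches, active):
    """Return the value of the branch whose condition is active, else the default."""
    matches = [cond for cond in branches if cond != DEFAULT and cond in active]
    if len(matches) > 1:
        # Simplification: treat any multiple match as an error.
        raise ValueError(f"ambiguous select: {matches}")
    if matches:
        return branches[matches[0]]
    if DEFAULT in branches:
        return branches[DEFAULT]
    raise KeyError("no matching condition and no default branch")


# Mirrors the _no_parentheses definition from the patch.
no_parentheses = {
    ":compiler_is_msvc": [],
    DEFAULT: ["-Wno-parentheses"],
}

msvc_copts = resolve_select(no_parentheses, active={":compiler_is_msvc"})
other_copts = resolve_select(no_parentheses, active=set())
print("MSVC: ", msvc_copts)
print("other:", other_copts)
```

With the MSVC condition active the target gets no extra copts, so cl.exe never sees `-Wno-parentheses`; every other configuration keeps the flag.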

hawkinsp commented 1 year ago

https://github.com/openxla/xla/commit/972cd211a24458be7a867678059a4a4652955f9f fixed this (at XLA head).