ROCm / aotriton

Ahead of Time (AOT) Triton Math Library
MIT License

[Issue]: Pytorch fails to compile locally due to aotriton failing to build the hsaco objects #18

Status: Open (Zakhrov opened this issue 4 months ago)

Zakhrov commented 4 months ago

Problem Description

PyTorch fails to compile locally with aotriton and throws the following error:

 make -j 6 -f Makefile.shim HIPCC=hipcc AR=/usr/bin/ar EXTRA_COMPILER_OPTIONS=-I/opt/rocm/include/\ \ -O3\ -DNDEBUG
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/shim.attn_fwd.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/shim.attn_fwd.o -c -fPIC -std=c++20 
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,True___MI200.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,True___MI200.o -c 
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,False___MI200.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,False___MI200.o -c 
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,True___MI200.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,True___MI200.o -c 
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,False___MI200.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,False___MI200.o -c 
hipcc -I/opt/rocm/include/  -O3 -DNDEBUG  -DAOTRITON_USE_ZSTD=0 -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/third_party/incbin -I/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include -fPIC -std=c++20 /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,False,True,True___MI200.cc -o /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,False,True,True___MI200.o -c 
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,True___MI200.cc:11:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/../shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,False___MI200.cc:11:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/../shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/shim.attn_fwd.cc:5:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,True___MI200.cc:11:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/../shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:19: flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,True___MI200.o] Error 1
make: *** Waiting for unfinished jobs....
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,False___MI200.cc:11:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/../shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,False,True,True___MI200.cc:11:
In file included from /home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/build/v2src/flash/autotune.attn_fwd/../shim.attn_fwd.h:6:
/home/aaron/Projects/personal/PythonPlayground/pytorch/build/aotriton/src/include/aotriton/_internal/triton_kernel.h:23:62: error: no template named 'vector' in namespace 'std'
   23 |   hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);
      |                                                         ~~~~~^
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:15: flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,False___MI200.o] Error 1
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:11: flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,True,True___MI200.o] Error 1
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:23: flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,True,False,False___MI200.o] Error 1
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:27: flash/autotune.attn_fwd/FONLY__^fp16@16,1,16,False,True,True___MI200.o] Error 1
1 error generated when compiling for gfx1010.
make: *** [Makefile.shim:1290: flash/shim.attn_fwd.o] Error 1

This happens even when the USE_FLASH_ATTENTION option is set to OFF.

Operating System

openSUSE Leap 15.5

CPU

AMD Ryzen 5 4600H

GPU

AMD Radeon Pro VII

ROCm Version

ROCm 6.0.0

ROCm Component

No response

Steps to Reproduce

Try to build PyTorch with: PYTORCH_ROCM_ARCH=gfx1010 USE_FLASH_ATTENTION=OFF USE_ROCM=ON ROCM_PATH=/opt/rocm python3 setup.py develop

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

ROCk module is loaded
=====================    
HSA System Attributes    
=====================    
Runtime Version:         1.13
Runtime Ext Version:     1.4
System Timestamp Freq.:  1000.000000MHz
Sig. Max Wait Duration:  18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model:           LARGE                              
System Endianness:       LITTLE                             
Mwaitx:                  DISABLED
DMAbuf Support:          NO

==========               
HSA Agents               
==========               
*******                  
Agent 1                  
*******                  
  Name:                    AMD Ryzen 5 4600H with Radeon Graphics
  Uuid:                    CPU-XX                             
  Marketing Name:          AMD Ryzen 5 4600H with Radeon Graphics
  Vendor Name:             CPU                                
  Feature:                 None specified                     
  Profile:                 FULL_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        0(0x0)                             
  Queue Min Size:          0(0x0)                             
  Queue Max Size:          0(0x0)                             
  Queue Type:              MULTI                              
  Node:                    0                                  
  Device Type:             CPU                                
  Cache Info:              
    L1:                      32768(0x8000) KB                   
  Chip ID:                 0(0x0)                             
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   3000                               
  BDFID:                   0                                  
  Internal Node ID:        0                                  
  Compute Unit:            12                                 
  SIMDs per CU:            0                                  
  Shader Engines:          0                                  
  Shader Arrs. per Eng.:   0                                  
  WatchPts on Addr. Ranges:1                                  
  Features:                None
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: FINE GRAINED        
      Size:                    48669032(0x2e6a168) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: KERNARG, FINE GRAINED
      Size:                    48669032(0x2e6a168) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 3                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    48669032(0x2e6a168) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
  ISA Info:                
*******                  
Agent 2                  
*******                  
  Name:                    gfx1010                            
  Uuid:                    GPU-XX                             
  Marketing Name:          AMD Radeon RX 5600M                
  Vendor Name:             AMD                                
  Feature:                 KERNEL_DISPATCH                    
  Profile:                 BASE_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        128(0x80)                          
  Queue Min Size:          64(0x40)                           
  Queue Max Size:          131072(0x20000)                    
  Queue Type:              MULTI                              
  Node:                    1                                  
  Device Type:             GPU                                
  Cache Info:              
    L1:                      16(0x10) KB                        
    L2:                      4096(0x1000) KB                    
  Chip ID:                 29471(0x731f)                      
  ASIC Revision:           2(0x2)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   1750                               
  BDFID:                   768                                
  Internal Node ID:        1                                  
  Compute Unit:            36                                 
  SIMDs per CU:            2                                  
  Shader Engines:          2                                  
  Shader Arrs. per Eng.:   2                                  
  WatchPts on Addr. Ranges:4                                  
  Coherent Host Access:    FALSE                              
  Features:                KERNEL_DISPATCH 
  Fast F16 Operation:      TRUE                               
  Wavefront Size:          32(0x20)                           
  Workgroup Max Size:      1024(0x400)                        
  Workgroup Max Size per Dimension:
    x                        1024(0x400)                        
    y                        1024(0x400)                        
    z                        1024(0x400)                        
  Max Waves Per CU:        40(0x28)                           
  Max Work-item Per CU:    1280(0x500)                        
  Grid Max Size:           4294967295(0xffffffff)             
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)             
    y                        4294967295(0xffffffff)             
    z                        4294967295(0xffffffff)             
  Max fbarriers/Workgrp:   32                                 
  Packet Processor uCode:: 146                                
  SDMA engine uCode::      35                                 
  IOMMU Support::          None                               
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    6275072(0x5fc000) KB               
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:2048KB                             
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: EXTENDED FINE GRAINED
      Size:                    6275072(0x5fc000) KB               
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:2048KB                             
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 3                   
      Segment:                 GROUP                              
      Size:                    64(0x40) KB                        
      Allocatable:             FALSE                              
      Alloc Granule:           0KB                                
      Alloc Recommended Granule:0KB                                
      Alloc Alignment:         0KB                                
      Accessible by all:       FALSE                              
  ISA Info:                
    ISA 1                    
      Name:                    amdgcn-amd-amdhsa--gfx1010:xnack-  
      Machine Models:          HSA_MACHINE_MODEL_LARGE            
      Profiles:                HSA_PROFILE_BASE                   
      Default Rounding Mode:   NEAR                               
      Default Rounding Mode:   NEAR                               
      Fast f16:                TRUE                               
      Workgroup Max Size:      1024(0x400)                        
      Workgroup Max Size per Dimension:
        x                        1024(0x400)                        
        y                        1024(0x400)                        
        z                        1024(0x400)                        
      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)             
        y                        4294967295(0xffffffff)             
        z                        4294967295(0xffffffff)             
      FBarrier Max Size:       32                                 
*******                  
Agent 3                  
*******                  
  Name:                    gfx90c                             
  Uuid:                    GPU-XX                             
  Marketing Name:                                             
  Vendor Name:             AMD                                
  Feature:                 KERNEL_DISPATCH                    
  Profile:                 BASE_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        128(0x80)                          
  Queue Min Size:          64(0x40)                           
  Queue Max Size:          131072(0x20000)                    
  Queue Type:              MULTI                              
  Node:                    2                                  
  Device Type:             GPU                                
  Cache Info:              
    L1:                      16(0x10) KB                        
    L2:                      1024(0x400) KB                     
  Chip ID:                 5686(0x1636)                       
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   1500                               
  BDFID:                   2048                               
  Internal Node ID:        2                                  
  Compute Unit:            6                                  
  SIMDs per CU:            4                                  
  Shader Engines:          1                                  
  Shader Arrs. per Eng.:   1                                  
  WatchPts on Addr. Ranges:4                                  
  Coherent Host Access:    FALSE                              
  Features:                KERNEL_DISPATCH 
  Fast F16 Operation:      TRUE                               
  Wavefront Size:          64(0x40)                           
  Workgroup Max Size:      1024(0x400)                        
  Workgroup Max Size per Dimension:
    x                        1024(0x400)                        
    y                        1024(0x400)                        
    z                        1024(0x400)                        
  Max Waves Per CU:        40(0x28)                           
  Max Work-item Per CU:    2560(0xa00)                        
  Grid Max Size:           4294967295(0xffffffff)             
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)             
    y                        4294967295(0xffffffff)             
    z                        4294967295(0xffffffff)             
  Max fbarriers/Workgrp:   32                                 
  Packet Processor uCode:: 468                                
  SDMA engine uCode::      40                                 
  IOMMU Support::          None                               
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    524288(0x80000) KB                 
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:2048KB                             
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: EXTENDED FINE GRAINED
      Size:                    524288(0x80000) KB                 
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Recommended Granule:2048KB                             
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 3                   
      Segment:                 GROUP                              
      Size:                    64(0x40) KB                        
      Allocatable:             FALSE                              
      Alloc Granule:           0KB                                
      Alloc Recommended Granule:0KB                                
      Alloc Alignment:         0KB                                
      Accessible by all:       FALSE                              
  ISA Info:                
    ISA 1                    
      Name:                    amdgcn-amd-amdhsa--gfx90c:xnack-   
      Machine Models:          HSA_MACHINE_MODEL_LARGE            
      Profiles:                HSA_PROFILE_BASE                   
      Default Rounding Mode:   NEAR                               
      Default Rounding Mode:   NEAR                               
      Fast f16:                TRUE                               
      Workgroup Max Size:      1024(0x400)                        
      Workgroup Max Size per Dimension:
        x                        1024(0x400)                        
        y                        1024(0x400)                        
        z                        1024(0x400)                        
      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)             
        y                        4294967295(0xffffffff)             
        z                        4294967295(0xffffffff)             
      FBarrier Max Size:       32                                 
*** Done ***             

Additional Information

No response

xinyazhang commented 4 months ago

This seems to be an #include-related problem. Can you give https://github.com/ROCm/aotriton/commit/0873896ab690d5767975f2bb9ab850b1a103b26e a try?
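The error at triton_kernel.h:23 says std::vector is used in a header that does not include <vector> itself, so the header only compiles when some earlier include happens to pull it in. The sketch below assumes the fix amounts to "include what you use"; it has not been checked against what commit 0873896 actually changes. hipError_t, dim3 and hipStream_t are replaced by hypothetical stand-in types so the snippet compiles without ROCm.

```cpp
// Sketch of a self-contained header in the spirit of
// aotriton/_internal/triton_kernel.h. The HIP types are hypothetical
// stand-ins; in real code they come from the HIP runtime headers.
#include <vector>  // without this, std::vector below triggers
                   // "no template named 'vector' in namespace 'std'"

namespace sketch {

// Hypothetical stand-ins for hipError_t, dim3 and hipStream_t.
enum StandInError { kSuccess = 0, kInvalidValue = 1 };
struct StandInDim3 { unsigned x = 1, y = 1, z = 1; };
using StandInStream = void*;

class TritonKernelSketch {
 public:
  // Same parameter shape as the declaration at triton_kernel.h:23.
  // Placeholder body: it only validates its arguments; the real
  // implementation would launch the named kernel on the stream.
  StandInError invoke(const char* kernel_name, StandInDim3 grid,
                      std::vector<void*>& args, StandInStream stream) {
    (void)stream;
    if (kernel_name == nullptr || grid.x == 0 || args.empty())
      return kInvalidValue;
    return kSuccess;
  }
};

}  // namespace sketch
```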

xinyazhang commented 4 months ago

Another problem is that Navi (aka RDNA) GPUs are not yet supported by this project. The only supported architectures are MI200/MI300 (gfx90a/gfx942), aka CDNA 2/3 GPUs. See answers in #16

We are going to add Navi support once the Triton compiler supports it.

Zakhrov commented 4 months ago

From what I have found, only gfx1100 supports the WMMA intrinsics. Can we make aotriton respect the PYTORCH_ROCM_ARCH variable so that it skips compilation altogether when none of the requested targets is supported?
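The requested behavior can be sketched as a filter over the target list: keep only the architectures aotriton supports and skip the aotriton build when nothing is left. This is a hypothetical helper, not aotriton's actual build logic; the supported set {gfx90a, gfx942} comes from the maintainer's comment, and the ';' separator assumes PyTorch's usual list convention for PYTORCH_ROCM_ARCH.

```cpp
#include <set>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: split a PYTORCH_ROCM_ARCH-style list such as
// "gfx90a;gfx942" on ';' and keep only targets found in `supported`.
std::vector<std::string> select_aotriton_targets(
    const std::string& rocm_arch, const std::set<std::string>& supported) {
  std::vector<std::string> selected;
  std::istringstream in(rocm_arch);
  std::string arch;
  while (std::getline(in, arch, ';')) {
    // Trim surrounding spaces; skip empty entries.
    const auto first = arch.find_first_not_of(' ');
    if (first == std::string::npos) continue;
    const auto last = arch.find_last_not_of(' ');
    arch = arch.substr(first, last - first + 1);
    if (supported.count(arch) != 0) selected.push_back(arch);
  }
  return selected;
}
```

With PYTORCH_ROCM_ARCH=gfx1010, as in this report, the result is empty, which is the signal a build could use to skip aotriton entirely.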

Zakhrov commented 4 months ago

Removing the if guard for AOTRITON_USE_ZSTD worked, but building the HIP kernels takes a really long time. I think a more elegant target-handling solution (like the one I mentioned above) would help reduce build times, particularly when debugging. Also, the HIP kernels were built for my native offload-arch (gfx1010) instead of with offload-arch=gfx90a or offload-arch=gfx942.
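The "1 error generated when compiling for gfx1010" lines in the first log suggest hipcc was given no explicit target and fell back to the local GPU. A minimal sketch of generating explicit --offload-arch flags from a target list follows; the helper names are hypothetical, and routing the result through Makefile.shim's EXTRA_COMPILER_OPTIONS (seen only in the make invocation above) is an untested assumption.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: one --offload-arch flag per target so device
// code is generated for the listed GPUs, not the build machine's GPU.
std::vector<std::string> offload_arch_flags(const std::vector<std::string>& targets) {
  std::vector<std::string> flags;
  flags.reserve(targets.size());
  for (const auto& t : targets) flags.push_back("--offload-arch=" + t);
  return flags;
}

// Join flags with single spaces into one command-line fragment.
std::string join_flags(const std::vector<std::string>& flags) {
  std::string out;
  for (const auto& f : flags) {
    if (!out.empty()) out += ' ';
    out += f;
  }
  return out;
}
```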

Zakhrov commented 3 weeks ago

With ROCm 6.2, the build fails with:

FAILED: v2src/libaotriton_v2.so 
: && hipcc -fPIC -O3 -DNDEBUG   -shared -Wl,-soname,libaotriton_v2.so -o v2src/libaotriton_v2.so @CMakeFiles/aotriton_v2.rsp  && :
ld.lld: error: /lib/libgcc_s.so.1 is incompatible with elf64-x86-64
ld.lld: error: /lib/libgcc_s.so.1 is incompatible with elf64-x86-64
clang++: error: linker command failed with exit code 1 (use -v to see invocation)
failed to execute:/opt/rocm-6.2.0/lib/llvm/bin/clang++ --driver-mode=g++ --hip-link  -fPIC -O3 -DNDEBUG -shared -Wl,-soname,libaotriton_v2.so -o "v2src/libaotriton_v2.so" \@CMakeFiles/aotriton_v2.rsp
ninja: build stopped: subcommand failed.

xinyazhang commented 3 weeks ago

You probably want to double-check your compiler or system environment: /lib/libgcc_s.so.1 is a 32-bit library and should not be present on any modern system.
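"incompatible with elf64-x86-64" means the linker found a file whose ELF class does not match the 64-bit link. The class is byte 4 of the ELF identification (EI_CLASS: 1 for 32-bit, 2 for 64-bit), right after the "\x7fELF" magic. A minimal checker sketch follows; the function and enum names are illustrative only.

```cpp
#include <array>
#include <fstream>
#include <string>

enum class ElfClass { NotElf, Elf32, Elf64, Unknown };

// Read the first five ELF identification bytes of `path`.
// Files that cannot be opened or are too short count as NotElf.
ElfClass elf_class_of(const std::string& path) {
  std::ifstream f(path, std::ios::binary);
  std::array<char, 5> ident{};
  if (!f.read(ident.data(), static_cast<std::streamsize>(ident.size())))
    return ElfClass::NotElf;
  if (ident[0] != '\x7f' || ident[1] != 'E' || ident[2] != 'L' || ident[3] != 'F')
    return ElfClass::NotElf;
  switch (ident[4]) {  // EI_CLASS
    case 1: return ElfClass::Elf32;
    case 2: return ElfClass::Elf64;
    default: return ElfClass::Unknown;
  }
}
```

From a shell, `file /lib/libgcc_s.so.1` reports the same information.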

Zakhrov commented 3 weeks ago

I removed the 32-bit version of libgcc_s and now get this error:

FAILED: v2src/libaotriton_v2.so 
: && hipcc -fPIC -O3 -DNDEBUG   -shared -Wl,-soname,libaotriton_v2.so -o v2src/libaotriton_v2.so @CMakeFiles/aotriton_v2.rsp  && :
ld.lld: error: /usr/lib64/gcc/x86_64-suse-linux/13/libgcc_s.so:4: unable to find libgcc_s.so.1
>>> GROUP ( libgcc_s.so.1 -lgcc )
>>>         ^
clang++: error: linker command failed with exit code 1 (use -v to see invocation)
failed to execute:/opt/rocm-6.2.0/lib/llvm/bin/clang++ --driver-mode=g++ --hip-link  -fPIC -O3 -DNDEBUG -shared -Wl,-soname,libaotriton_v2.so -o "v2src/libaotriton_v2.so" \@CMakeFiles/aotriton_v2.rsp
ninja: build stopped: subcommand failed.
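The `>>> GROUP ( libgcc_s.so.1 -lgcc )` lines show that /usr/lib64/gcc/x86_64-suse-linux/13/libgcc_s.so is a text linker script, not an ELF file: the linker reads it and must then locate libgcc_s.so.1 itself in its search directories, which is where this link presumably fails. A deliberately minimal sketch of extracting the GROUP entries (hypothetical helper; it does not strip comments and does not handle nested AS_NEEDED ( ... ) lists):

```cpp
#include <sstream>
#include <string>
#include <vector>

// Return the whitespace-separated entries of the first GROUP ( ... )
// command in a GNU linker-script text. Minimal: comments, AS_NEEDED
// and other script commands are not handled.
std::vector<std::string> group_entries(const std::string& script) {
  std::vector<std::string> entries;
  const auto kw = script.find("GROUP");
  if (kw == std::string::npos) return entries;
  const auto open = script.find('(', kw);
  if (open == std::string::npos) return entries;
  const auto close = script.find(')', open);
  if (close == std::string::npos) return entries;
  std::istringstream in(script.substr(open + 1, close - open - 1));
  std::string tok;
  while (in >> tok) entries.push_back(tok);
  return entries;
}
```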

Also, overriding the compiler with CC=clang CXX=clang++ doesn't work, because clang complains about variable-length arrays.
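Variable-length arrays (`T buf[n];` with a runtime `n`) are a C99 feature that C++ compilers accept only as an extension; depending on the clang version and warning flags, clang diagnoses them, and -Werror turns that into a failure. The thread does not show which aotriton code uses VLAs, so the function below is a hypothetical illustration of the standard-C++ replacement with std::vector.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical example of replacing a VLA with std::vector.
float sum_of_squares(const float* data, std::size_t n) {
  // float scratch[n];            // VLA: not standard C++
  std::vector<float> scratch(n);  // portable runtime-sized buffer (heap)
  for (std::size_t i = 0; i < n; ++i) scratch[i] = data[i] * data[i];
  float total = 0.0f;
  for (float v : scratch) total += v;
  return total;
}
```

The trade-off is a heap allocation instead of stack space, which is usually negligible in host-side code.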

Zakhrov commented 3 weeks ago

It looks like a problem with ROCm's LLVM linker, which does not seem to respect LD_LIBRARY_PATH or to skip incompatible libraries.

xinyazhang commented 3 weeks ago

> which seems to not respect LD_LIBRARY_PATH

LD_LIBRARY_PATH only takes top priority as a runtime environment variable; for ld, it is searched after the -rpath-link and -rpath options. See the -rpath-link= section of https://man7.org/linux/man-pages/man1/ld.1.html for details.
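The ordering the maintainer describes can be modeled as a first-match search over directory lists. This is a simplified sketch of the GNU ld order documented under -rpath-link (LD_RUN_PATH, SunOS -L, DT_RUNPATH/DT_RPATH and ld.so.conf steps omitted); ld.lld, which produced the errors above, may not follow the same order. The `usable` callback is a hypothetical stand-in for "this directory holds a compatible copy of the library".

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Directory lists in the order ld(1) consults them (simplified).
struct SearchInputs {
  std::vector<std::string> rpath_link;
  std::vector<std::string> rpath;
  std::vector<std::string> ld_library_path;
  std::vector<std::string> defaults{"/lib", "/usr/lib"};
};

// Return the first directory, in search order, for which `usable`
// holds; std::nullopt if no directory qualifies.
std::optional<std::string> find_needed(
    const SearchInputs& in, const std::function<bool(const std::string&)>& usable) {
  for (const auto* list : {&in.rpath_link, &in.rpath, &in.ld_library_path, &in.defaults})
    for (const auto& dir : *list)
      if (usable(dir)) return dir;
  return std::nullopt;
}
```

The point of the model: a directory in LD_LIBRARY_PATH is only reached if nothing earlier in the order matched.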

Also, you may want to use a container for your build environment. The closest publicly available image can be pulled with: docker pull rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0