Status: Open — reachtarunhere opened this issue 9 months ago
FlashAttention-2 is now supported on AMD GPUs with PyTorch; the implementation uses Triton kernels.
https://github.com/ROCmSoftwarePlatform/flash-attention
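JAX-Triton currently doesn't work on AMD GPUs. When I try the `add` example, I get the following error, which I suspect is caused by CUDA-specific code in `triton_lib.py`. The log supports this: the kernel was generated by the LLVM NVPTX back end for `.target sm_90a`, which is an NVIDIA target rather than an AMD one.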
I'm happy to run any other tests requested here to help make progress on this.
(/jax_miniconda) Singularity> python add.py 2024-01-22 13:12:41.578159: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory /usr/share/libdrm/amdgpu.ids: No such file or directory 2024-01-22 13:12:46.635298: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error: INTERNAL: // // Generated by LLVM NVPTX Back-End // .version 8.2 .target sm_90a .address_size 64 // .globl add_kernel_0d1d2d .visible .entry add_kernel_0d1d2d( .param .u64 add_kernel_0d1d2d_param_0, .param .u64 add_kernel_0d1d2d_param_1, .param .u64 add_kernel_0d1d2d_param_2 ) .maxntid 128, 1, 1 { .reg .pred %p<5>; .reg .b32 %r<10>; .reg .b64 %rd<8>; .loc 1 24 0 $L__func_begin0: .loc 1 24 0 ld.param.u64 %rd4, [add_kernel_0d1d2d_param_0]; ld.param.u64 %rd5, [add_kernel_0d1d2d_param_1]; $L__tmp0: .loc 1 33 39 mov.u32 %r5, %tid.x; and.b32 %r6, %r5, 7; ld.param.u64 %rd6, [add_kernel_0d1d2d_param_2]; .loc 1 31 22 mov.u32 %r1, %ctaid.x; .loc 1 32 22 shl.b32 %r7, %r1, 3; .loc 1 33 26 or.b32 %r8, %r7, %r6; .loc 1 34 19 setp.lt.s32 %p1, %r8, 8; .loc 1 35 22 mul.wide.s32 %rd7, %r8, 4; add.s64 %rd1, %rd4, %rd7; .loc 1 35 14 
mov.u32 %r2, 0x0; @%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ]; .loc 1 36 22 add.s64 %rd2, %rd5, %rd7; .loc 1 36 14 mov.u32 %r3, 0x0; @%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ]; .loc 1 37 15 add.s32 %r4, %r3, %r2; .loc 1 38 24 add.s64 %rd3, %rd6, %rd7; .loc 1 38 33 and.b32 %r9, %r5, 120; setp.eq.s32 %p4, %r9, 0; and.pred %p3, %p4, %p1; @%p3 st.global.b32 [ %rd3 + 0 ], { %r4 }; .loc 1 38 2 ret; $L__tmp1: $L__func_end0: } .file 1 "/jax_miniconda/add.py" .section .debug_abbrev { .b8 1 .b8 17 .b8 1 .b8 37 .b8 8 .b8 19 .b8 5 .b8 3 .b8 8 .b8 16 .b8 6 .b8 27 .b8 8 .b8 180 .b8 66 .b8 12 .b8 17 .b8 1 .b8 18 .b8 1 .b8 0 .b8 0 .b8 2 .b8 46 .b8 0 .b8 17 .b8 1 .b8 18 .b8 1 .b8 64 .b8 10 .b8 135 .b8 64 .b8 8 .b8 3 .b8 8 .b8 58 .b8 11 .b8 59 .b8 11 .b8 63 .b8 12 .b8 0 .b8 0 .b8 0 } .section .debug_info { .b32 119 .b8 2 .b8 0 .b32 .debug_abbrev .b8 8 .b8 1 .b8 116 .b8 114 .b8 105 .b8 116 .b8 111 .b8 110 .b8 0 .b8 2 .b8 0 .b8 97 .b8 100 .b8 100 .b8 46 .b8 112 .b8 121 .b8 0 .b32 .debug_line .b8 47 .b8 106 .b8 97 .b8 120 .b8 95 .b8 109 .b8 105 .b8 110 .b8 105 .b8 99 .b8 111 .b8 110 .b8 100 .b8 97 .b8 0 .b8 1 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 2 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 1 .b8 156 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 1 .b8 24 .b8 1 .b8 0 } .section .debug_pubnames { .b32 $L__pubNames_end0-$L__pubNames_start0 $L__pubNames_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 64 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b32 0 $L__pubNames_end0: } .section .debug_pubtypes { .b32 $L__pubTypes_end0-$L__pubTypes_start0 $L__pubTypes_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 0 $L__pubTypes_end0: } .section .debug_loc { } ; No such file or 
directory 2024-01-22 13:12:46.635704: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2716] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: // // Generated by LLVM NVPTX Back-End // .version 8.2 .target sm_90a .address_size 64 // .globl add_kernel_0d1d2d .visible .entry add_kernel_0d1d2d( .param .u64 add_kernel_0d1d2d_param_0, .param .u64 add_kernel_0d1d2d_param_1, .param .u64 add_kernel_0d1d2d_param_2 ) .maxntid 128, 1, 1 { .reg .pred %p<5>; .reg .b32 %r<10>; .reg .b64 %rd<8>; .loc 1 24 0 $L__func_begin0: .loc 1 24 0 ld.param.u64 %rd4, [add_kernel_0d1d2d_param_0]; ld.param.u64 %rd5, [add_kernel_0d1d2d_param_1]; $L__tmp0: .loc 1 33 39 mov.u32 %r5, %tid.x; and.b32 %r6, %r5, 7; ld.param.u64 %rd6, [add_kernel_0d1d2d_param_2]; .loc 1 31 22 mov.u32 %r1, %ctaid.x; .loc 1 32 22 shl.b32 %r7, %r1, 3; .loc 1 33 26 or.b32 %r8, %r7, %r6; .loc 1 34 19 setp.lt.s32 %p1, %r8, 8; .loc 1 35 22 mul.wide.s32 %rd7, %r8, 4; add.s64 %rd1, %rd4, %rd7; .loc 1 35 14 mov.u32 %r2, 0x0; @%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ]; .loc 1 36 22 add.s64 %rd2, %rd5, %rd7; .loc 1 36 14 mov.u32 %r3, 0x0; @%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ]; .loc 1 37 15 add.s32 %r4, %r3, %r2; .loc 1 38 24 add.s64 %rd3, %rd6, %rd7; .loc 1 38 33 and.b32 %r9, %r5, 120; setp.eq.s32 %p4, %r9, 0; and.pred %p3, %p4, %p1; @%p3 st.global.b32 [ %rd3 + 0 ], { %r4 }; .loc 1 38 2 ret; $L__tmp1: $L__func_end0: } .file 1 "/jax_miniconda/add.py" .section .debug_abbrev { .b8 1 .b8 17 .b8 1 .b8 37 .b8 8 .b8 19 .b8 5 .b8 3 .b8 8 .b8 16 .b8 6 .b8 27 .b8 8 .b8 180 .b8 66 .b8 12 .b8 17 .b8 1 .b8 18 .b8 1 .b8 0 .b8 0 .b8 2 .b8 46 .b8 0 .b8 17 .b8 1 .b8 18 .b8 1 .b8 64 .b8 10 .b8 135 .b8 64 .b8 8 .b8 3 .b8 8 .b8 58 .b8 11 .b8 59 .b8 11 .b8 63 .b8 12 .b8 0 .b8 0 .b8 0 } .section .debug_info { .b32 119 .b8 2 .b8 0 .b32 .debug_abbrev .b8 8 .b8 1 .b8 116 .b8 114 .b8 105 .b8 116 .b8 111 .b8 110 .b8 0 .b8 2 .b8 0 .b8 97 .b8 100 .b8 100 .b8 46 
.b8 112 .b8 121 .b8 0 .b32 .debug_line .b8 47 .b8 106 .b8 97 .b8 120 .b8 95 .b8 109 .b8 105 .b8 110 .b8 105 .b8 99 .b8 111 .b8 110 .b8 100 .b8 97 .b8 0 .b8 1 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 2 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 1 .b8 156 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 1 .b8 24 .b8 1 .b8 0 } .section .debug_pubnames { .b32 $L__pubNames_end0-$L__pubNames_start0 $L__pubNames_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 64 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b32 0 $L__pubNames_end0: } .section .debug_pubtypes { .b32 $L__pubTypes_end0-$L__pubTypes_start0 $L__pubTypes_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 0 $L__pubTypes_end0: } .section .debug_loc { } ; No such file or directory; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#. jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. 
The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/jax_miniconda/add.py", line 56, in <module> print(add(x_val, y_val)) File "/jax_miniconda/add.py", line 44, in add return jt.triton_call( File "/jax_miniconda/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 681, in triton_call out_flat = triton_kernel_call_p.bind( File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 402, in bind return self.bind_with_trace(find_top_trace(args), args, params) File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 405, in bind_with_trace out = trace.process_primitive(self, map(trace.full_raise, args), params) File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 893, in process_primitive return primitive.impl(*tracers, **params) File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive outs = fun(*args) jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: // // Generated by LLVM NVPTX Back-End // .version 8.2 .target sm_90a .address_size 64 // .globl add_kernel_0d1d2d .visible .entry add_kernel_0d1d2d( .param .u64 add_kernel_0d1d2d_param_0, .param .u64 add_kernel_0d1d2d_param_1, .param .u64 add_kernel_0d1d2d_param_2 ) .maxntid 128, 1, 1 { .reg .pred %p<5>; .reg .b32 %r<10>; .reg .b64 %rd<8>; .loc 1 24 0 $L__func_begin0: .loc 1 24 0 ld.param.u64 %rd4, [add_kernel_0d1d2d_param_0]; ld.param.u64 %rd5, [add_kernel_0d1d2d_param_1]; $L__tmp0: .loc 1 33 39 mov.u32 %r5, %tid.x; and.b32 %r6, %r5, 7; ld.param.u64 %rd6, [add_kernel_0d1d2d_param_2]; .loc 1 31 22 mov.u32 %r1, %ctaid.x; .loc 1 32 22 shl.b32 %r7, %r1, 3; .loc 1 33 26 or.b32 %r8, %r7, %r6; .loc 1 34 19 setp.lt.s32 %p1, %r8, 8; .loc 1 35 22 mul.wide.s32 %rd7, %r8, 4; add.s64 %rd1, %rd4, %rd7; .loc 1 35 14 mov.u32 %r2, 0x0; @%p1 
ld.global.b32 { %r2 }, [ %rd1 + 0 ]; .loc 1 36 22 add.s64 %rd2, %rd5, %rd7; .loc 1 36 14 mov.u32 %r3, 0x0; @%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ]; .loc 1 37 15 add.s32 %r4, %r3, %r2; .loc 1 38 24 add.s64 %rd3, %rd6, %rd7; .loc 1 38 33 and.b32 %r9, %r5, 120; setp.eq.s32 %p4, %r9, 0; and.pred %p3, %p4, %p1; @%p3 st.global.b32 [ %rd3 + 0 ], { %r4 }; .loc 1 38 2 ret; $L__tmp1: $L__func_end0: } .file 1 "/jax_miniconda/add.py" .section .debug_abbrev { .b8 1 .b8 17 .b8 1 .b8 37 .b8 8 .b8 19 .b8 5 .b8 3 .b8 8 .b8 16 .b8 6 .b8 27 .b8 8 .b8 180 .b8 66 .b8 12 .b8 17 .b8 1 .b8 18 .b8 1 .b8 0 .b8 0 .b8 2 .b8 46 .b8 0 .b8 17 .b8 1 .b8 18 .b8 1 .b8 64 .b8 10 .b8 135 .b8 64 .b8 8 .b8 3 .b8 8 .b8 58 .b8 11 .b8 59 .b8 11 .b8 63 .b8 12 .b8 0 .b8 0 .b8 0 } .section .debug_info { .b32 119 .b8 2 .b8 0 .b32 .debug_abbrev .b8 8 .b8 1 .b8 116 .b8 114 .b8 105 .b8 116 .b8 111 .b8 110 .b8 0 .b8 2 .b8 0 .b8 97 .b8 100 .b8 100 .b8 46 .b8 112 .b8 121 .b8 0 .b32 .debug_line .b8 47 .b8 106 .b8 97 .b8 120 .b8 95 .b8 109 .b8 105 .b8 110 .b8 105 .b8 99 .b8 111 .b8 110 .b8 100 .b8 97 .b8 0 .b8 1 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 2 .b64 $L__func_begin0 .b64 $L__func_end0 .b8 1 .b8 156 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b8 1 .b8 24 .b8 1 .b8 0 } .section .debug_pubnames { .b32 $L__pubNames_end0-$L__pubNames_start0 $L__pubNames_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 64 .b8 97 .b8 100 .b8 100 .b8 95 .b8 107 .b8 101 .b8 114 .b8 110 .b8 101 .b8 108 .b8 95 .b8 48 .b8 100 .b8 49 .b8 100 .b8 50 .b8 100 .b8 0 .b32 0 $L__pubNames_end0: } .section .debug_pubtypes { .b32 $L__pubTypes_end0-$L__pubTypes_start0 $L__pubTypes_start0: .b8 2 .b8 0 .b32 .debug_info .b32 123 .b32 0 $L__pubTypes_end0: } .section .debug_loc { } ; No such file or directory; current 
tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#.
FlashAttention-2 is now supported on AMD GPUs with PyTorch; the implementation uses Triton kernels.
https://github.com/ROCmSoftwarePlatform/flash-attention
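JAX-Triton currently doesn't work on AMD GPUs. When I try the `add` example, I get the following error, which I suspect is caused by CUDA-specific code in `triton_lib.py`. The log supports this: the kernel was generated by the LLVM NVPTX back end for `.target sm_90a`, which is an NVIDIA target rather than an AMD one.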
I'm happy to run any other tests requested here to help make progress on this.