google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support for AMD/ROCm with Pallas #19453

Open reachtarunhere opened 5 months ago

reachtarunhere commented 5 months ago

I have JAX and Triton installed. When I run the code below, I get the error shown in the traceback that follows.

I assume this is because XLA is treating my GPU as an NVIDIA GPU?

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

bs = 2
seqlen = 1000
n_heads = 32
dim = 128
rng = jax.random.PRNGKey(0)
xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xk = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xv = jax.random.normal(rng, (bs, seqlen, n_heads, dim))

print('reference')
res = attention.mha_reference(xq, xk, xv, None)
print(res)
print(res.shape)

print('real kernel')
print(attention.mha(xq, xk, xv, None))
>>> print(res.shape)
(2, 1000, 32, 128)
>>> print('real kernel')
real kernel
>>> print(attention.mha(xq, xk, xv, None))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/attention.py", line 216, in mha
    return pl.pallas_call(
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 410, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AttributeError: module 'triton.compiler.compiler' has no attribute 'CudaTargetDescriptor'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1551, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1493, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 241, in lower_jaxpr_to_triton_module
    builder.target = tc.CudaTargetDescriptor(
AttributeError: module 'triton.compiler.compiler' has no attribute 'CudaTargetDescriptor'
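
A quick way to narrow this down is to check which Triton is actually installed, since the AttributeError above means this Pallas lowering expects an API (CudaTargetDescriptor) that the installed triton package does not provide. A minimal diagnostic sketch, using only the module already named in the traceback:

# Diagnostic sketch: check the installed Triton version and whether it exposes
# the attribute the Pallas/Triton lowering is looking for.
import triton
import triton.compiler.compiler as tc

print(triton.__version__)
print(hasattr(tc, "CudaTargetDescriptor"))  # prints False in the setup above

If this prints False, the JAX/Pallas version and the Triton build are mismatched (here, an NVIDIA-oriented lowering path meeting a ROCm Triton), rather than anything being wrong with the kernel call itself.
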
hawkinsp commented 5 months ago

I don't think we've ever tried this on AMD GPUs. Contributions welcome!

@rahulbatra85 ?

rahulbatra85 commented 5 months ago

@hawkinsp @reachtarunhere Yes, we are working on upstreaming support for Pallas, along with Triton and jax-triton, on AMD.

reachtarunhere commented 5 months ago

@rahulbatra85 Thanks, I did see that you have a fork of jax-triton. Is there anything, even semi-broken, that I can try right now? Also, let me know if there is something in particular I can take up to help. Thanks :)

rahulbatra85 commented 5 months ago

@reachtarunhere Please try this Docker image for now: docker pull rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton

reachtarunhere commented 5 months ago

Thanks, will do. Are there any binaries available? I am on LUMI, and the step of converting from Docker to Singularity usually breaks JAX containers for me.

ttim commented 5 months ago

@rahulbatra85 I've tried the above code with the JAX image you provided, and it failed with the following error:

NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cpu

Any tips on how to run it? Thanks!

rahulbatra85 commented 5 months ago

@ttim The above code runs totally fine for me in the container I provided. In your case, it's running on the CPU.

By the way, AFAIK, pallas_call will lower to CPU only if interpret mode is true: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py#L179C5-L179C14
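
For reference, here is a minimal sketch of running the same kernel in interpreter mode, which executes the Pallas kernel with plain JAX ops and therefore works on CPU (slowly). This assumes the attention.mha wrapper in this JAX version accepts an interpret keyword, as the linked attention.py source suggests:

# Sketch: run the Pallas attention kernel in interpreter mode (CPU-friendly).
# Assumes attention.mha accepts an `interpret` keyword, per the linked source.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (2, 1024, 32, 128))  # (batch, seqlen, heads, head_dim)

out = attention.mha(x, x, x, None, interpret=True)  # evaluated via the interpreter, no Triton lowering
print(out.shape)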

ttim commented 5 months ago

@rahulbatra85 It seems like, in my case, JAX doesn't recognize the GPU:

> sudo docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G docker.io/rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton
> python
> import jax
> jax.devices()
[CpuDevice(id=0)]

What's the correct way to run the Docker container so that JAX can see the available AMD GPUs? I used the docker command line from the PyTorch ROCm tutorial.

Thank you very much for helping!

rahulbatra85 commented 5 months ago

docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G docker.io/rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton

There's no issue with how you are running Docker.

What kind of AMD GPU do you have? Do these commands work OK for you: rocm-smi, rocminfo?

ttim commented 5 months ago

@rahulbatra85 2x 7900 XTX. rocm-smi output:

====================================== ROCm System Management Interface ======================================
================================================ Concise Info ================================================
Device  [Model : Revision]    Temp    Power  Partitions      SCLK     MCLK    Fan  Perf  PwrCap  VRAM%  GPU%  
        Name (20 chars)       (Edge)  (Avg)  (Mem, Compute)                                                   
==============================================================================================================
0       [0x471e : 0xc8]       28.0°C  79.0W  N/A, N/A        1484Mhz  456Mhz  0%   auto  303.0W    0%   25%   
        0x744c                                                                                                
1       [0x471e : 0xc8]       25.0°C  42.0W  N/A, N/A        1930Mhz  96Mhz   0%   auto  303.0W    0%   33%   
        0x744c                                                                                                
==============================================================================================================
============================================ End of ROCm SMI Log =============================================

rocminfo:

ROCk module is loaded
=====================    
HSA System Attributes    
=====================    
Runtime Version:         1.1
System Timestamp Freq.:  1000.000000MHz
Sig. Max Wait Duration:  18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model:           LARGE                              
System Endianness:       LITTLE                             
Mwaitx:                  DISABLED
DMAbuf Support:          YES

==========               
HSA Agents               
==========               
*******                  
Agent 1                  
*******                  
  Name:                    AMD EPYC 7252 8-Core Processor     
  Uuid:                    CPU-XX                             
  Marketing Name:          AMD EPYC 7252 8-Core Processor     
  Vendor Name:             CPU                                
  Feature:                 None specified                     
  Profile:                 FULL_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        0(0x0)                             
  Queue Min Size:          0(0x0)                             
  Queue Max Size:          0(0x0)                             
  Queue Type:              MULTI                              
  Node:                    0                                  
  Device Type:             CPU                                
  Cache Info:              
    L1:                      32768(0x8000) KB                   
  Chip ID:                 0(0x0)                             
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   3100                               
  BDFID:                   0                                  
  Internal Node ID:        0                                  
  Compute Unit:            16                                 
  SIMDs per CU:            0                                  
  Shader Engines:          0                                  
  Shader Arrs. per Eng.:   0                                  
  WatchPts on Addr. Ranges:1                                  
  Features:                None
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: FINE GRAINED        
      Size:                    131768212(0x7da9f94) KB            
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: KERNARG, FINE GRAINED
      Size:                    131768212(0x7da9f94) KB            
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 3                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    131768212(0x7da9f94) KB            
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
  ISA Info:                
*******                  
Agent 2                  
*******                  
  Name:                    gfx1100                            
  Uuid:                    GPU-a18fbfbd757cdfd0               
  Marketing Name:          Radeon RX 7900 XTX                 
  Vendor Name:             AMD                                
  Feature:                 KERNEL_DISPATCH                    
  Profile:                 BASE_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        128(0x80)                          
  Queue Min Size:          64(0x40)                           
  Queue Max Size:          131072(0x20000)                    
  Queue Type:              MULTI                              
  Node:                    1                                  
  Device Type:             GPU                                
  Cache Info:              
    L1:                      32(0x20) KB                        
    L2:                      6144(0x1800) KB                    
    L3:                      98304(0x18000) KB                  
  Chip ID:                 29772(0x744c)                      
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   2371                               
  BDFID:                   49920                              
  Internal Node ID:        1                                  
  Compute Unit:            96                                 
  SIMDs per CU:            2                                  
  Shader Engines:          6                                  
  Shader Arrs. per Eng.:   2                                  
  WatchPts on Addr. Ranges:4                                  
  Coherent Host Access:    FALSE                              
  Features:                KERNEL_DISPATCH 
  Fast F16 Operation:      TRUE                               
  Wavefront Size:          32(0x20)                           
  Workgroup Max Size:      1024(0x400)                        
  Workgroup Max Size per Dimension:
    x                        1024(0x400)                        
    y                        1024(0x400)                        
    z                        1024(0x400)                        
  Max Waves Per CU:        32(0x20)                           
  Max Work-item Per CU:    1024(0x400)                        
  Grid Max Size:           4294967295(0xffffffff)             
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)             
    y                        4294967295(0xffffffff)             
    z                        4294967295(0xffffffff)             
  Max fbarriers/Workgrp:   32                                 
  Packet Processor uCode:: 550                                
  SDMA engine uCode::      19                                 
  IOMMU Support::          None                               
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    25149440(0x17fc000) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: EXTENDED FINE GRAINED
      Size:                    25149440(0x17fc000) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 3                   
      Segment:                 GROUP                              
      Size:                    64(0x40) KB                        
      Allocatable:             FALSE                              
      Alloc Granule:           0KB                                
      Alloc Alignment:         0KB                                
      Accessible by all:       FALSE                              
  ISA Info:                
    ISA 1                    
      Name:                    amdgcn-amd-amdhsa--gfx1100         
      Machine Models:          HSA_MACHINE_MODEL_LARGE            
      Profiles:                HSA_PROFILE_BASE                   
      Default Rounding Mode:   NEAR                               
      Default Rounding Mode:   NEAR                               
      Fast f16:                TRUE                               
      Workgroup Max Size:      1024(0x400)                        
      Workgroup Max Size per Dimension:
        x                        1024(0x400)                        
        y                        1024(0x400)                        
        z                        1024(0x400)                        
      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)             
        y                        4294967295(0xffffffff)             
        z                        4294967295(0xffffffff)             
      FBarrier Max Size:       32                                 
*******                  
Agent 3                  
*******                  
  Name:                    gfx1100                            
  Uuid:                    GPU-d59bfe6e2d839dab               
  Marketing Name:          Radeon RX 7900 XTX                 
  Vendor Name:             AMD                                
  Feature:                 KERNEL_DISPATCH                    
  Profile:                 BASE_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        128(0x80)                          
  Queue Min Size:          64(0x40)                           
  Queue Max Size:          131072(0x20000)                    
  Queue Type:              MULTI                              
  Node:                    2                                  
  Device Type:             GPU                                
  Cache Info:              
    L1:                      32(0x20) KB                        
    L2:                      6144(0x1800) KB                    
    L3:                      98304(0x18000) KB                  
  Chip ID:                 29772(0x744c)                      
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   2371                               
  BDFID:                   33536                              
  Internal Node ID:        2                                  
  Compute Unit:            96                                 
  SIMDs per CU:            2                                  
  Shader Engines:          6                                  
  Shader Arrs. per Eng.:   2                                  
  WatchPts on Addr. Ranges:4                                  
  Coherent Host Access:    FALSE                              
  Features:                KERNEL_DISPATCH 
  Fast F16 Operation:      TRUE                               
  Wavefront Size:          32(0x20)                           
  Workgroup Max Size:      1024(0x400)                        
  Workgroup Max Size per Dimension:
    x                        1024(0x400)                        
    y                        1024(0x400)                        
    z                        1024(0x400)                        
  Max Waves Per CU:        32(0x20)                           
  Max Work-item Per CU:    1024(0x400)                        
  Grid Max Size:           4294967295(0xffffffff)             
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)             
    y                        4294967295(0xffffffff)             
    z                        4294967295(0xffffffff)             
  Max fbarriers/Workgrp:   32                                 
  Packet Processor uCode:: 550                                
  SDMA engine uCode::      19                                 
  IOMMU Support::          None                               
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    25149440(0x17fc000) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: EXTENDED FINE GRAINED
      Size:                    25149440(0x17fc000) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 3                   
      Segment:                 GROUP                              
      Size:                    64(0x40) KB                        
      Allocatable:             FALSE                              
      Alloc Granule:           0KB                                
      Alloc Alignment:         0KB                                
      Accessible by all:       FALSE                              
  ISA Info:                
    ISA 1                    
      Name:                    amdgcn-amd-amdhsa--gfx1100         
      Machine Models:          HSA_MACHINE_MODEL_LARGE            
      Profiles:                HSA_PROFILE_BASE                   
      Default Rounding Mode:   NEAR                               
      Default Rounding Mode:   NEAR                               
      Fast f16:                TRUE                               
      Workgroup Max Size:      1024(0x400)                        
      Workgroup Max Size per Dimension:
        x                        1024(0x400)                        
        y                        1024(0x400)                        
        z                        1024(0x400)                        
      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)             
        y                        4294967295(0xffffffff)             
        z                        4294967295(0xffffffff)             
      FBarrier Max Size:       32                                 
*** Done ***        
rahulbatra85 commented 5 months ago

@ttim Unfortunately, JAX is currently not supported on Navi, i.e. the 7900 XTX. We are working on adding support for this platform and hope to have it available soon.

Thanks!

reachtarunhere commented 5 months ago

@rahulbatra85 I couldn't get it to work because our HPC has issues detecting GPUs inside Singularity containers with JAX. Can you please share the exact branches/versions of the libraries below so I can try building them myself? Could you also point out whether there is a minimum ROCm version requirement for this to work?

jaxlib, jax, jax-triton, triton

rahulbatra85 commented 5 months ago

> @rahulbatra85 I couldn't get it to work because our HPC has issues detecting GPUs inside Singularity containers with JAX. Can you please share the exact branches/versions of the libraries below so I can try building them myself? Could you also point out whether there is a minimum ROCm version requirement for this to work?
>
> jaxlib, jax, jax-triton, triton

@reachtarunhere
https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.20-rocm6.0-jax-triton
https://github.com/ROCm/xla/tree/rocm-jaxlib-v0.4.20-rocm6.0
https://github.com/ROCm/triton/tree/triton-jax-triton
https://github.com/rahulbatra85/jax-triton/tree/jax-triton-rocm

ROCm 5.7 or ROCm 6.0
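
A quick sanity check after building from those branches, before trying any Pallas kernels (a sketch using only standard JAX APIs, nothing ROCm-specific assumed):

# Sanity-check sketch: confirm the freshly built jaxlib actually sees the AMD GPUs.
import jax

print(jax.__version__)        # should match the 0.4.20 branch built above
print(jax.default_backend())  # expect 'gpu'; 'cpu' means the ROCm backend was not picked up
print(jax.devices())          # should list GPU devices, not [CpuDevice(id=0)]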

reachtarunhere commented 5 months ago

@rahulbatra85 Thank you! I tried the setup proposed above, and I run into the error below when I call the MHA kernel as shown at the start of the issue.

Any ideas what might be causing this? I am on an MI250X.

Traceback (most recent call last):
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 332, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 900, in _dot_general_lowering
    return tl.dot(
  File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/core.py", line 31, in wrapper
    return fn(*args, **kwargs)
  File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/core.py", line 971, in dot
    return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
  File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/semantic.py", line 1264, in dot
    assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
AssertionError: First input (bf16) and second input (fp32) must have the same dtype!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/experimental/pallas/ops/attention.py", line 211, in mha
    return pl.pallas_call(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/pallas_call.py", line 383, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
  a:f32[128,128] = dot_general[
  dimension_numbers=(([1], [0]), ([], []))
  preferred_element_type=float32
] b c
With context:
  TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x153d4b7d4f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x153d4b7845e0>, module=<triton._C.libtriton.triton.ir.module object at 0x153d4b784630>, grid_mapping=GridMapping(grid=(8, 2, 32), block_mappings=(BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[<triton.language.core.tensor object at 0x153e4b92ded0>, <triton.language.core.tensor object at 0x153e4b92ea70>, <triton.language.core.tensor object at 0x153e4b92f4c0>]), avals_in=[ShapedArray(bfloat16[128,64]), ShapedArray(float32[64,128])], avals_out=[ShapedArray(float32[128,128])], block_infos=[None, None])
With inval shapes=[[constexpr[128], constexpr[64]], [constexpr[64], constexpr[128]]]
With inval types=[<[128, 64], bf16>, <[64, 128], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[1000,128]} b:f32[128,128] c:Ref{float32[1000,128]} d:i32[]
    e:f32[128,128] f:f32[128] g:f32[128]. let
    h:i32[] = mul d 64
    i:f32[64,128] <- a[h:h+64,:]
    j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
    k:f32[128,64] = transpose[permutation=(1, 0)] i
    l:f32[128,64] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] b k
    m:f32[128,64] = add j l
    n:f32[128] = reduce_max[axes=(1,)] m
    o:f32[128] = max f n
    p:f32[128] = sub f o
    q:f32[128] = exp2 p
    r:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] o
    s:f32[128,64] = sub m r
    t:f32[128,64] = exp2 s
    u:f32[128] = mul g q
    v:f32[128] = reduce_sum[axes=(1,)] t
    w:f32[128] = add u v
    x:f32[128] = mul g 0.0
    y:f32[128] = add x q
    z:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] y
    ba:f32[128,128] = mul e z
    bb:i32[] = mul d 64
    bc:f32[64,128] <- c[bb:bb+64,:]
    bd:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] t
    be:f32[128,128] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] bd bc
    bf:f32[128,128] = add ba be
  in (bf, o, w) }

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1541, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1483, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 277, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 332, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1240, in _scan_lowering_rule
    for_out = _lower_jaxpr_to_for_loop(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1190, in _lower_jaxpr_to_for_loop
    all_out = lower_jaxpr_to_triton_ir(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 336, in lower_jaxpr_to_triton_ir
    raise TritonLoweringException(
jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
  a:f32[128,128] = dot_general[
  dimension_numbers=(([1], [0]), ([], []))
  preferred_element_type=float32
] b c
With context:
  TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x153d4b7d4f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x153d4b7845e0>, module=<triton._C.libtriton.triton.ir.module object at 0x153d4b784630>, grid_mapping=GridMapping(grid=(8, 2, 32), block_mappings=(BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }, memory_space=None)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[<triton.language.core.tensor object at 0x153e4b92ded0>, <triton.language.core.tensor object at 0x153e4b92ea70>, <triton.language.core.tensor object at 0x153e4b92f4c0>]), avals_in=[ShapedArray(bfloat16[128,64]), ShapedArray(float32[64,128])], avals_out=[ShapedArray(float32[128,128])], block_infos=[None, None])
With inval shapes=[[constexpr[128], constexpr[64]], [constexpr[64], constexpr[128]]]
With inval types=[<[128, 64], bf16>, <[64, 128], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[1000,128]} b:f32[128,128] c:Ref{float32[1000,128]} d:i32[]
    e:f32[128,128] f:f32[128] g:f32[128]. let
    h:i32[] = mul d 64
    i:f32[64,128] <- a[h:h+64,:]
    j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
    k:f32[128,64] = transpose[permutation=(1, 0)] i
    l:f32[128,64] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] b k
    m:f32[128,64] = add j l
    n:f32[128] = reduce_max[axes=(1,)] m
    o:f32[128] = max f n
    p:f32[128] = sub f o
    q:f32[128] = exp2 p
    r:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] o
    s:f32[128,64] = sub m r
    t:f32[128,64] = exp2 s
    u:f32[128] = mul g q
    v:f32[128] = reduce_sum[axes=(1,)] t
    w:f32[128] = add u v
    x:f32[128] = mul g 0.0
    y:f32[128] = add x q
    z:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] y
    ba:f32[128,128] = mul e z
    bb:i32[] = mul d 64
    bc:f32[64,128] <- c[bb:bb+64,:]
    bd:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] t
    be:f32[128,128] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] bd bc
    bf:f32[128,128] = add ba be
  in (bf, o, w) }
reachtarunhere commented 5 months ago

Further, after making all the inputs bfloat16, I run into a different error:

>>> xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim), dtype="bfloat16")
>>> print(attention.mha(xq, xq, xq, None))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/experimental/pallas/ops/attention.py", line 211, in mha
    return pl.pallas_call(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/pallas_call.py", line 383, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: ttir_to_ttgir() got an unexpected keyword argument 'num_ctas'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1541, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1488, in compile_jaxpr
    ptx, name, shared_mem_bytes, compute_capability = compile_ttir_to_ptx_inplace(
  File "/recommended_jax/jax-triton-jax-triton-rocm/jax_triton/triton_lib.py", line 256, in compile_ttir_to_ptx_inplace
    ttgir = tc.ttir_to_ttgir(ttir, num_warps, warpsize=64, num_ctas=num_ctas, target=arch_full_details)
TypeError: ttir_to_ttgir() got an unexpected keyword argument 'num_ctas'
reachtarunhere commented 4 months ago

Bump. When is this expected?

rahulbatra85 commented 4 months ago

@reachtarunhere Can you try these branches instead?
https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.24-jax-triton
https://github.com/ROCm/xla/commits/rocm-jaxlib-v0.4.24-jax-triton/
https://github.com/rahulbatra85/jax-triton/tree/jax-triton-rocm-0.4.24
https://github.com/ROCm/triton/tree/jax-triton-rocm-0.4.24

reachtarunhere commented 4 months ago

@rahulbatra85 Thanks for the update. I will try them and get back to you. On our main server (LUMI) we are unfortunately stuck with ROCm 5.6.1. Is there any way to build these for < ROCm 5.7?

Thank you!