reachtarunhere opened this issue 5 months ago
I don't think we've ever tried this on AMD GPUs. Contributions welcome!
@rahulbatra85 ?
@hawkinsp @reachtarunhere Yes, we are working on upstreaming support for Pallas, along with Triton and jax-triton, on AMD.
@rahulbatra85 thanks, I did see that you had a fork for jax-triton. Is there anything semi-broken I can try right now? Also let me know if there is something particular I can take up to help. Thanks :)
@reachtarunhere Please try this Docker image for now:
docker pull rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton
Thanks, will do. Are there any binaries available? I am on LUMI, and the step of converting from Docker to Singularity usually messes up JAX containers for me.
@rahulbatra85 I've tried the above code with the JAX image you provided, and it failed with the following error:
NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cpu
Any tips on how to run it? Thanks!
@ttim The above code runs totally fine for me in the container I provided. In your case, it's running on CPU.
By the way, AFAIK, pallas_call will only lower to CPU if interpret mode is true: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py#L179C5-L179C14
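For reference, here is a minimal sketch (not from this thread; the kernel and shapes are made up for illustration) of a trivial kernel run through pl.pallas_call with interpret=True, which is what lets it lower on CPU without a GPU/Triton backend:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Read the whole input block, add 1, write the output block.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # lower via the interpreter instead of Triton, so CPU works
)(x)
print(out)  # [1. 2. ... 8.]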
@rahulbatra85 It seems that in my case JAX doesn't recognize the GPU:
> sudo docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G docker.io/rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton
> python
> import jax
> jax.devices()
[CpuDevice(id=0)]
What's the correct way to run the Docker container so that JAX can see the available AMD GPUs? I used the docker command line from the PyTorch ROCm tutorial.
Thank you very much for helping!
docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G docker.io/rocm/jax-build:rocm6.0.0-jax0.4.20-py3.10.0-jax_triton
There's no issue with how you are running Docker.
What kind of AMD GPU do you have? Do these commands work OK for you?
rocm-smi
rocminfo
@rahulbatra85 2x 7900 XTX
rocm-smi
:
====================================== ROCm System Management Interface ======================================
================================================ Concise Info ================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Edge) (Avg) (Mem, Compute)
==============================================================================================================
0 [0x471e : 0xc8] 28.0°C 79.0W N/A, N/A 1484Mhz 456Mhz 0% auto 303.0W 0% 25%
0x744c
1 [0x471e : 0xc8] 25.0°C 42.0W N/A, N/A 1930Mhz 96Mhz 0% auto 303.0W 0% 33%
0x744c
==============================================================================================================
============================================ End of ROCm SMI Log =============================================
rocminfo
:
ROCk module is loaded
=====================
HSA System Attributes
=====================
Runtime Version: 1.1
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
DMAbuf Support: YES
==========
HSA Agents
==========
*******
Agent 1
*******
Name: AMD EPYC 7252 8-Core Processor
Uuid: CPU-XX
Marketing Name: AMD EPYC 7252 8-Core Processor
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 3100
BDFID: 0
Internal Node ID: 0
Compute Unit: 16
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 131768212(0x7da9f94) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 131768212(0x7da9f94) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 131768212(0x7da9f94) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:
*******
Agent 2
*******
Name: gfx1100
Uuid: GPU-a18fbfbd757cdfd0
Marketing Name: Radeon RX 7900 XTX
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 6144(0x1800) KB
L3: 98304(0x18000) KB
Chip ID: 29772(0x744c)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 2371
BDFID: 49920
Internal Node ID: 1
Compute Unit: 96
SIMDs per CU: 2
Shader Engines: 6
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 550
SDMA engine uCode:: 19
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1100
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32
*******
Agent 3
*******
Name: gfx1100
Uuid: GPU-d59bfe6e2d839dab
Marketing Name: Radeon RX 7900 XTX
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 2
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 6144(0x1800) KB
L3: 98304(0x18000) KB
Chip ID: 29772(0x744c)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 2371
BDFID: 33536
Internal Node ID: 2
Compute Unit: 96
SIMDs per CU: 2
Shader Engines: 6
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 550
SDMA engine uCode:: 19
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1100
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32
*** Done ***
@ttim Unfortunately, JAX is currently not supported on Navi, i.e. the 7900 XTX. We are working on adding support for this platform and hope to have it supported soon.
Thanks!
@rahulbatra85 I couldn't get it to work because our HPC has issues detecting GPUs inside Singularity containers with JAX. Can you please share the exact branches/versions of the libraries below so I can try building them myself? Can you also point out whether there is a minimum ROCm version requirement for this to work?
jaxlib jax jax-triton triton
@reachtarunhere https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.20-rocm6.0-jax-triton https://github.com/ROCm/xla/tree/rocm-jaxlib-v0.4.20-rocm6.0 https://github.com/ROCm/triton/tree/triton-jax-triton https://github.com/rahulbatra85/jax-triton/tree/jax-triton-rocm
ROCm 5.7 or ROCm 6.0
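If you build from those branches, one quick sanity check (a sketch, not from this thread) is to confirm which installs actually get imported before running anything; note that the jax_triton __version__ attribute below is an assumption and may not exist in every build:

import jax, jaxlib, triton

print("jax    :", jax.__version__, jax.__file__)
print("jaxlib :", jaxlib.__version__)
print("triton :", triton.__version__, triton.__file__)
print("devices:", jax.devices())

try:
    import jax_triton
    print("jax_triton:", getattr(jax_triton, "__version__", "unknown"), jax_triton.__file__)
except ImportError:
    print("jax_triton not installed")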
@rahulbatra85 thank you! I tried the setup proposed above, and I run into this error when I run the MHA example shown at the start of the issue.
Any ideas what might be causing this? I am on an MI250X.
Traceback (most recent call last):
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 332, in lower_jaxpr_to_triton_ir
outvals = rule(rule_ctx, *invals, **eqn.params)
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 900, in _dot_general_lowering
return tl.dot(
File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/core.py", line 31, in wrapper
return fn(*args, **kwargs)
File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/core.py", line 971, in dot
return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
File "/recommended_jax/jax_build_venv2/lib/python3.10/site-packages/triton/language/semantic.py", line 1264, in dot
assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
AssertionError: First input (bf16) and second input (fp32) must have the same dtype!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/experimental/pallas/ops/attention.py", line 211, in mha
return pl.pallas_call(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/pallas_call.py", line 383, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
a:f32[128,128] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] b c
With context:
TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x153d4b7d4f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x153d4b7845e0>, module=<triton._C.libtriton.triton.ir.module object at 0x153d4b784630>, grid_mapping=GridMapping(grid=(8, 2, 32), block_mappings=(BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[<triton.language.core.tensor object at 0x153e4b92ded0>, <triton.language.core.tensor object at 0x153e4b92ea70>, <triton.language.core.tensor object at 0x153e4b92f4c0>]), avals_in=[ShapedArray(bfloat16[128,64]), ShapedArray(float32[64,128])], avals_out=[ShapedArray(float32[128,128])], block_infos=[None, None])
With inval shapes=[[constexpr[128], constexpr[64]], [constexpr[64], constexpr[128]]]
With inval types=[<[128, 64], bf16>, <[64, 128], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[1000,128]} b:f32[128,128] c:Ref{float32[1000,128]} d:i32[]
e:f32[128,128] f:f32[128] g:f32[128]. let
h:i32[] = mul d 64
i:f32[64,128] <- a[h:h+64,:]
j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
k:f32[128,64] = transpose[permutation=(1, 0)] i
l:f32[128,64] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] b k
m:f32[128,64] = add j l
n:f32[128] = reduce_max[axes=(1,)] m
o:f32[128] = max f n
p:f32[128] = sub f o
q:f32[128] = exp2 p
r:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] o
s:f32[128,64] = sub m r
t:f32[128,64] = exp2 s
u:f32[128] = mul g q
v:f32[128] = reduce_sum[axes=(1,)] t
w:f32[128] = add u v
x:f32[128] = mul g 0.0
y:f32[128] = add x q
z:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] y
ba:f32[128,128] = mul e z
bb:i32[] = mul d 64
bc:f32[64,128] <- c[bb:bb+64,:]
bd:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] t
be:f32[128,128] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] bd bc
bf:f32[128,128] = add ba be
in (bf, o, w) }
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1541, in pallas_call_lowering
compilation_result = compile_jaxpr(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1483, in compile_jaxpr
lowering_result = lower_jaxpr_to_triton_module(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 277, in lower_jaxpr_to_triton_module
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 332, in lower_jaxpr_to_triton_ir
outvals = rule(rule_ctx, *invals, **eqn.params)
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1240, in _scan_lowering_rule
for_out = _lower_jaxpr_to_for_loop(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1190, in _lower_jaxpr_to_for_loop
all_out = lower_jaxpr_to_triton_ir(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 336, in lower_jaxpr_to_triton_ir
raise TritonLoweringException(
jax._src.pallas.triton.lowering.TritonLoweringException: Exception while lowering eqn:
a:f32[128,128] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] b c
With context:
TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x153d4b7d4f30>, builder=<triton._C.libtriton.triton.ir.builder object at 0x153d4b7845e0>, module=<triton._C.libtriton.triton.ir.module object at 0x153d4b784630>, grid_mapping=GridMapping(grid=(8, 2, 32), block_mappings=(BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None), BlockMapping(block_shape=(<jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 1000, <jax._src.pallas.core.Mapped object at 0x154253e6e0e0>, 128), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }, memory_space=None)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[<triton.language.core.tensor object at 0x153e4b92ded0>, <triton.language.core.tensor object at 0x153e4b92ea70>, <triton.language.core.tensor object at 0x153e4b92f4c0>]), avals_in=[ShapedArray(bfloat16[128,64]), ShapedArray(float32[64,128])], avals_out=[ShapedArray(float32[128,128])], block_infos=[None, None])
With inval shapes=[[constexpr[128], constexpr[64]], [constexpr[64], constexpr[128]]]
With inval types=[<[128, 64], bf16>, <[64, 128], fp32>]
In jaxpr:
{ lambda ; a:Ref{float32[1000,128]} b:f32[128,128] c:Ref{float32[1000,128]} d:i32[]
e:f32[128,128] f:f32[128] g:f32[128]. let
h:i32[] = mul d 64
i:f32[64,128] <- a[h:h+64,:]
j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
k:f32[128,64] = transpose[permutation=(1, 0)] i
l:f32[128,64] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] b k
m:f32[128,64] = add j l
n:f32[128] = reduce_max[axes=(1,)] m
o:f32[128] = max f n
p:f32[128] = sub f o
q:f32[128] = exp2 p
r:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] o
s:f32[128,64] = sub m r
t:f32[128,64] = exp2 s
u:f32[128] = mul g q
v:f32[128] = reduce_sum[axes=(1,)] t
w:f32[128] = add u v
x:f32[128] = mul g 0.0
y:f32[128] = add x q
z:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] y
ba:f32[128,128] = mul e z
bb:i32[] = mul d 64
bc:f32[64,128] <- c[bb:bb+64,:]
bd:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] t
be:f32[128,128] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] bd bc
bf:f32[128,128] = add ba be
in (bf, o, w) }
Further, on making all the inputs bfloat16, I run into a different error:
>>> xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim), dtype="bfloat16")
>>> print(attention.mha(xq, xq, xq, None))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/experimental/pallas/ops/attention.py", line 211, in mha
return pl.pallas_call(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/pallas_call.py", line 383, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: ttir_to_ttgir() got an unexpected keyword argument 'num_ctas'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1541, in pallas_call_lowering
compilation_result = compile_jaxpr(
File "/recommended_jax/jax-rocm-jaxlib-v0.4.20-rocm6.0-jax-triton/jax/_src/pallas/triton/lowering.py", line 1488, in compile_jaxpr
ptx, name, shared_mem_bytes, compute_capability = compile_ttir_to_ptx_inplace(
File "/recommended_jax/jax-triton-jax-triton-rocm/jax_triton/triton_lib.py", line 256, in compile_ttir_to_ptx_inplace
ttgir = tc.ttir_to_ttgir(ttir, num_warps, warpsize=64, num_ctas=num_ctas, target=arch_full_details)
TypeError: ttir_to_ttgir() got an unexpected keyword argument 'num_ctas'
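For reference, the first assertion (bf16 vs fp32) just says that the Triton lowering of dot_general wants both operands of a dot in the same dtype. Below is a hedged, self-contained sketch of that pattern, not the attention kernel itself; the names, shapes, and dtypes are made up for illustration, and interpret mode is used only so it runs anywhere:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    x = x_ref[...]                   # bfloat16 block
    y = y_ref[...].astype(x.dtype)   # cast so both dot operands have the same dtype
    o_ref[...] = pl.dot(x, y).astype(o_ref.dtype)  # pl.dot accumulates in float32

x = jax.random.normal(jax.random.PRNGKey(0), (128, 64), dtype=jnp.bfloat16)
y = jax.random.normal(jax.random.PRNGKey(1), (64, 128), dtype=jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    interpret=True,
)(x, y)
print(out.shape, out.dtype)

The second error (ttir_to_ttgir() got an unexpected keyword argument 'num_ctas') looks like an API mismatch between the jax-triton branch and the installed Triton rather than something in the kernel itself.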
Bump. When is this expected?
@reachtarunhere Can you try these branches instead? https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.24-jax-triton https://github.com/ROCm/xla/commits/rocm-jaxlib-v0.4.24-jax-triton/ https://github.com/rahulbatra85/jax-triton/tree/jax-triton-rocm-0.4.24 https://github.com/ROCm/triton/tree/jax-triton-rocm-0.4.24
@rahulbatra85 thanks for the update. I will try it and get back to you. On our main server (LUMI) we are unfortunately stuck with ROCm 5.6.1. Is there any way to build these for ROCm < 5.7?
Thank you!
I have JAX and Triton installed. On trying the code below I get the following error:
I assume this is because XLA is treating my GPU as an NVIDIA GPU?
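A quick, hedged way to check which backend jaxlib actually resolved to (this is not the snippet referenced above; device reprs vary between jaxlib builds):

import jax

print(jax.default_backend())   # expect "gpu" on a working ROCm build, "cpu" otherwise
for d in jax.devices():
    print(d, d.platform, d.device_kind)

If this reports cpu, jaxlib is not seeing the ROCm device at all.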