google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

ROCm 6.1, 7900 xtx: bfloat16 support not enabled? #21074

Open brettkoonce opened 1 month ago

brettkoonce commented 1 month ago

Description

@rahulbatra85 I managed to get JAX to train ResNet-50 using the scenic library and the latest docker images, thanks to #18747.

When I change my code from float32 to float16, I get a roughly 2x speedup. However, when I enable bfloat16, the system runs at float32 speed. Am I correct in assuming that the docker + 7900 XTX (RDNA3) + ROCm 6.1 stack is not generating bfloat16 code yet?
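One way to separate this from the training pipeline is a standalone matmul micro-benchmark across the three dtypes. The following is a minimal sketch, assuming a square matmul is representative enough of the workload; the names `time_matmul` and `compare_dtypes` and the matrix sizes are illustrative and do not come from scenic or the original training code.

```python
import time

import jax
import jax.numpy as jnp

# Dtypes to compare; the keys are only labels for the report.
DTYPES = {
    "float32": jnp.float32,
    "float16": jnp.float16,
    "bfloat16": jnp.bfloat16,
}


def time_matmul(dtype, n=1024, repeats=10):
    """Return the mean wall-clock seconds for one n x n matmul in `dtype`."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n), dtype=dtype)
    b = jax.random.normal(key_b, (n, n), dtype=dtype)
    matmul = jax.jit(jnp.matmul)
    # Compile and warm up outside the timed region.
    matmul(a, b).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for each result.
        matmul(a, b).block_until_ready()
    return (time.perf_counter() - start) / repeats


def compare_dtypes(n=1024, repeats=10):
    """Time the same matmul in every dtype listed in DTYPES."""
    return {name: time_matmul(dtype, n, repeats) for name, dtype in DTYPES.items()}


if __name__ == "__main__":
    print("devices:", jax.devices())
    for name, seconds in compare_dtypes().items():
        print(f"{name:>9}: {seconds * 1e3:.3f} ms per matmul")
```

If bfloat16 lands near the float32 time while float16 is clearly faster, that would be consistent with what I see in training, though a single matmul is not proof of what the compiler emits for the full model.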

System info (python version, jaxlib version, accelerator, etc.)

ubuntu 22.04 lts
rocm 6.1
7900 xtx
rocm/jax:latest image

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.0 (default, Apr 9 2024, 03:46:30) [GCC 9.4.0]
jax.devices (1 total, 1 local): [rocm(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='mars', release='5.15.0-105-generic', version='#115-Ubuntu SMP Mon Apr 15 09:52:04 UTC 2024', machine='x86_64')

brettkoonce commented 1 month ago

Still seeing this with ROCm 6.1.1 and the latest images!

platform: uname_result(system='Linux', node='mars', release='5.15.0-106-generic', version='#116-Ubuntu SMP Wed Apr 17 09:17:56 UTC 2024', machine='x86_64')