@rahulbatra85
I managed to get JAX to train ResNet-50 using the Scenic library and the latest Docker images, thanks to #18747.
When I switch my code from `float32` to `float16`, I get roughly a 2x speedup. However, when I enable `bfloat16`, training runs at the same speed as `float32`. Am I correct in assuming that the Docker + 7900 XTX (RDNA3) + ROCm 6.1 setup does not generate `bfloat16` code yet?
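One way to separate the dtype effect from the rest of the Scenic training loop is to time a single jitted matmul in each precision on the same device. This is a minimal sketch, not the reporter's code: the helper name `time_matmul`, the matrix size and the iteration count are arbitrary assumptions, and a matmul only approximates the convolution-heavy ResNet-50 workload.

```python
import time

import jax
import jax.numpy as jnp


def time_matmul(dtype, n=1024, iters=10):
    """Return mean seconds per jitted n x n matmul in the given dtype (hypothetical helper)."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n), dtype=jnp.float32).astype(dtype)
    b = jax.random.normal(key_b, (n, n), dtype=jnp.float32).astype(dtype)
    matmul = jax.jit(jnp.matmul)
    # The first call triggers compilation; block on it so compile time is not measured.
    matmul(a, b).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        out = matmul(a, b)
    # JAX dispatches asynchronously, so wait for the last result before stopping the clock.
    out.block_until_ready()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    for dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
        print(f"{jnp.dtype(dtype).name:>9}: {time_matmul(dtype) * 1e3:.2f} ms")
```

If `bfloat16` lands near the `float32` timing while `float16` is clearly faster, that would point at the backend's bf16 matmul path rather than at Scenic.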
System info (python version, jaxlib version, accelerator, etc.)
Ubuntu 22.04 LTS, ROCm 6.1, Radeon RX 7900 XTX, `rocm/jax:latest` image
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.0 (default, Apr 9 2024, 03:46:30) [GCC 9.4.0]
jax.devices (1 total, 1 local): [rocm(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='mars', release='5.15.0-105-generic', version='#115-Ubuntu SMP Mon Apr 15 09:52:04 UTC 2024', machine='x86_64')
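To check whether XLA is actually emitting `bfloat16` matmuls on this setup, one option is to inspect the optimized HLO that JAX compiles. This is a sketch under assumptions: the helper name `matmul_hlo_lines` is hypothetical, and the filter on `dot(` / `custom-call(` is a heuristic. GPU backends typically lower a matmul to a custom-call into a BLAS library, while CPU may keep a plain `dot`.

```python
import jax
import jax.numpy as jnp


def matmul_hlo_lines(dtype, n=256):
    """Return optimized-HLO lines that look like the matmul for the given dtype.

    Operand and result types in these lines (e.g. bf16[...], f16[...], f32[...])
    show which precision the backend chose after XLA's optimization passes.
    """
    x = jnp.ones((n, n), dtype=dtype)
    # Ahead-of-time lower and compile, then dump the optimized HLO as text.
    compiled = jax.jit(jnp.matmul).lower(x, x).compile()
    return [
        line.strip()
        for line in compiled.as_text().splitlines()
        if " dot(" in line or "custom-call(" in line
    ]
```

Running `print(*matmul_hlo_lines(jnp.bfloat16), sep="\n")` on the ROCm machine and comparing it with the `jnp.float16` output should show whether the gemm receives `bf16` operands or whether they are converted to `f32` first.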