chenmoneygithub opened this issue 1 year ago
Hi @chenmoneygithub -
The sheet is accessible for me. Mixed precision will only speed up models on recent NVIDIA GPUs and Google TPUs. NVIDIA GPUs support a mix of float16 and float32, while TPUs support a mix of bfloat16 and float32. You can find more details here.
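For reference, the policy itself is a one-line switch in Keras; a minimal sketch (my own illustration, not code from the report), assuming Keras 3 or tf.keras:

```python
import keras

# "mixed_bfloat16" runs compute in bfloat16 with float32 variables (the TPU mix);
# "mixed_float16" is the float16/float32 mix for recent NVIDIA GPUs.
keras.mixed_precision.set_global_policy("mixed_bfloat16")

# Layers created after this point inherit the policy automatically.
model = keras.Sequential([
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(10),
])
print(model.layers[0].dtype_policy)  # prints the active mixed_bfloat16 policy
```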
On which hardware are you using mixed_bfloat16 and float32?
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
@mehtamansi29 Thanks for looking into that! I am not sure the result is still valid; that benchmark was run before the first official release of Keras 3. The TPU was a v3-8, which is a fairly old generation as of today.
In short, we observed that `mixed_bfloat16` on TPU is slower than `float32` in our model benchmarks. Please refer to this sheet (internal only) for comparison results. To reproduce with the JAX backend, on a TPU VM, use the command below:
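The command itself isn't captured in the thread above; as a rough stand-in (my own sketch, with a made-up toy model and sizes, not the original benchmark), a comparison on the JAX backend might look like:

```python
# Hypothetical reproduction sketch (not the original command): compare float32 vs.
# mixed_bfloat16 training time with the JAX backend on a TPU VM.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import time
import numpy as np
import keras

def time_training(policy: str) -> float:
    keras.mixed_precision.set_global_policy(policy)
    model = keras.Sequential([
        keras.layers.Dense(1024, activation="relu"),
        keras.layers.Dense(1024, activation="relu"),
        keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    x = np.random.rand(8192, 512).astype("float32")
    y = np.random.randint(0, 10, size=(8192,))
    model.fit(x, y, batch_size=256, epochs=1, verbose=0)  # warm-up / compilation
    start = time.time()
    model.fit(x, y, batch_size=256, epochs=3, verbose=0)
    return time.time() - start

for policy in ("float32", "mixed_bfloat16"):
    print(f"{policy}: {time_training(policy):.2f}s")
```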
To reproduce with the TF backend, you need to modify the code to connect to the TPU and use a TPU strategy.
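For what it's worth, a sketch of those TF-backend changes (again my own illustration; the `tpu="local"` resolver argument assumes a Cloud TPU VM):

```python
# Sketch of the TF-backend modifications: connect to the TPU and build the model
# under a TPUStrategy scope. Assumes a Cloud TPU VM (hence tpu="local").
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

keras.mixed_precision.set_global_policy("mixed_bfloat16")

with strategy.scope():
    # Variables must be created inside the strategy scope to be placed on the TPU.
    model = keras.Sequential([
        keras.layers.Dense(1024, activation="relu"),
        keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
```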