keras-team / keras

Deep Learning for humans
http://keras.io/

`mixed_bfloat16` in TPU is slower than `float32` #18448

Open chenmoneygithub opened 1 year ago

chenmoneygithub commented 1 year ago

In short, we observed that mixed_bfloat16 on TPU is slower than float32 in our model benchmarks. Please refer to this sheet (internal only) for the comparison results.

To reproduce with the JAX backend on a TPU VM, use the command below:

cd benchmarks
KERAS_BACKEND=jax python3 -m model_benchmark.image_classification_benchmark  \
   --model="ResNet50V2"  \
   --epochs=1 \
   --batch_size=32 \
   --mixed_precision_policy="mixed_bfloat16"

To reproduce with the TF backend, you need to modify the code to connect to the TPU and use a TPU strategy, as sketched below.
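For reference, a minimal sketch of that modification, assuming the standard TF2 TPU setup on a TPU VM (the exact place to hook this into the benchmark script is not shown here):

import tensorflow as tf
import keras

# Resolve and initialize the TPU; an empty tpu argument picks up the local TPU on a TPU VM.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Model creation and compilation must happen inside the strategy scope
# so that variables are replicated across the TPU cores.
with strategy.scope():
    model = keras.applications.ResNet50V2(weights=None)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")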

mehtamansi29 commented 2 weeks ago

Hi @chenmoneygithub -

The sheet is not accessible to me. Mixed precision will only speed up models on recent NVIDIA GPUs and Google TPUs: NVIDIA GPUs support a mix of float16 and float32, while TPUs support a mix of bfloat16 and float32. You can find more details here.
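For context, the policy is set globally before the model is built (the benchmark script does this through its --mixed_precision_policy flag); a minimal sketch with the Keras 3 API:

import keras

# Compute in bfloat16 while keeping variables in float32.
keras.config.set_dtype_policy("mixed_bfloat16")
# Equivalently: keras.mixed_precision.set_global_policy("mixed_bfloat16")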

On which hardware are you using mixed_bfloat16 and float32?

github-actions[bot] commented 4 days ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

chenmoneygithub commented 3 days ago

@mehtamansi29 Thanks for looking into that! I am not sure if the result is still valid; that was a benchmark I ran before the first official release of Keras 3. The TPU was a v3-8, which is a very old generation as of today.