
Use the bfloat16 datatype (8 mantissa bits) for internal computations with GEMM/CONV/RNN #5

Open · zhuhaozhe opened this issue 7 months ago

🚀 The Feature

This RFC proposes to use BFloat16 for GEMM/CONV/RNN internal computations on the CPU device, controlled by a user-facing frontend API. Currently, torch.set_float32_matmul_precision already allows float32 matrix multiplications to run in lower precision.

Frontend changes:

These frontend APIs should behave the same way as torch.set_float32_matmul_precision and torch.get_float32_matmul_precision. Users can set the precision to highest, high, or medium. When the precision is high, the CUDA/cuDNN backend is allowed to use TF32 as the internal computation data type. When the precision is medium, the MKLDNN backend is allowed to use BF16 as the internal computation data type. A sketch of the existing matmul knob that the new APIs are meant to mirror is shown below.
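For reference, a minimal sketch of the existing API; only the matmul setter/getter below exists today, and the comments restate the precision levels as described in this RFC:

```python
import torch

# Existing API: select the internal precision used for float32 matmuls.
torch.set_float32_matmul_precision("medium")
print(torch.get_float32_matmul_precision())   # -> "medium"

# "highest" keeps pure float32 accumulation,
# "high"    lets the CUDA/cuDNN backend use TF32 internally,
# "medium"  additionally lets the MKLDNN (oneDNN) CPU backend use BF16.
torch.set_float32_matmul_precision("highest")
```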

Backend changes:

We will then use BF16 as the internal computation data type; a PR has already been created.
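The backend change itself lives in the C++/oneDNN layer, but the numerical effect can be illustrated in a few lines: the user still passes float32 tensors, while the kernel computes in BF16. This is only a rough emulation, not the actual kernel path:

```python
import torch

a = torch.randn(256, 256)   # float32 tensors, as the user sees them
b = torch.randn(256, 256)

ref = a @ b                                    # full float32 compute

# Emulate "BF16 as the internal computation data type": round the operands
# to bfloat16, multiply, and hand back a float32 result.
approx = (a.bfloat16() @ b.bfloat16()).float()

rel_err = ((approx - ref).abs().max() / ref.abs().max()).item()
print(f"max relative error from BF16 internal compute: {rel_err:.2e}")
```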

Inductor changes:

Motivation

The new BF16 TMUL instructions on Intel Xeon server processors can improve user application performance. With these frontend APIs, users can control the internal computation data type for GEMM/CONV/RNN even when the model's data type is Float32, so Float32 models can benefit from BF16 acceleration without being converted.

Pitch

Provide float32_conv_precision and float32_rnn_precision, and enable the bfloat16 datatype for internal computations with the MKLDNN backend when the precision is set to medium (see the sketch below).
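A minimal sketch of what the proposed knob could look like if it mirrors the matmul API. set_float32_conv_precision / get_float32_conv_precision are hypothetical spellings of the pitched float32_conv_precision setting, not an existing PyTorch API; float32_rnn_precision would follow the same pattern:

```python
_VALID_PRECISIONS = ("highest", "high", "medium")
_float32_conv_precision = "highest"   # hypothetical module-level flag, mirroring the matmul default

def set_float32_conv_precision(precision: str) -> None:
    """Select the internal compute type for float32 convolutions;
    "medium" would let the MKLDNN backend compute in BF16."""
    global _float32_conv_precision
    if precision not in _VALID_PRECISIONS:
        raise ValueError(f"precision must be one of {_VALID_PRECISIONS}, got {precision!r}")
    _float32_conv_precision = precision

def get_float32_conv_precision() -> str:
    return _float32_conv_precision

set_float32_conv_precision("medium")   # opt in to BF16 internal compute for conv
print(get_float32_conv_precision())    # -> "medium"
```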

Additional context

Design option: Front end API

Design option: Inductor linear packable rules
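The rule set itself is not spelled out in this issue; below is only a hypothetical illustration of the kind of predicate such a rule could be, deciding whether a float32 linear may be lowered to an MKLDNN packed linear that computes in BF16. All names are made up for the sketch:

```python
import torch

def bf16_packed_linear_allowed(weight: torch.Tensor,
                               x: torch.Tensor,
                               matmul_precision: str) -> bool:
    """Hypothetical packable rule: allow the BF16-internal packed linear only
    for float32 inputs/weights and only when the user chose "medium"."""
    if weight.dtype is not torch.float32 or x.dtype is not torch.float32:
        return False                 # the RFC targets float32 models
    if x.dim() < 2 or weight.dim() != 2:
        return False                 # need a plain (out_features, in_features) weight
    return matmul_precision == "medium"

w = torch.randn(64, 128)
inp = torch.randn(32, 128)
# Default matmul precision is "highest", so this prints False until the user opts in.
print(bf16_packed_linear_allowed(w, inp, torch.get_float32_matmul_precision()))
```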