NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Common] Fused cast transpose kernels refactoring #884

Closed · Oleg-Goncharov closed 3 months ago

Oleg-Goncharov commented 4 months ago

Description

The existing fused cast-transpose kernel code is replicated across the different scenarios (i.e., +dbias, +dactivation) with only small scenario-specific modifications. Replacing these copies with a single function template makes the code easier to maintain and allows new features (e.g., scaling, JIT compilation) to be added in a simple way.
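The idea can be sketched as follows. This is a minimal host-side illustration, not the actual TransformerEngine code: the flag names, the `dgelu_approx` stand-in, and the function signature are all hypothetical, and the real kernels run this logic on the GPU with tiled shared-memory accesses.

```cpp
#include <cstdio>

// Hypothetical compile-time flags selecting the optional fusions, so one
// template body replaces the previously replicated kernel variants.
enum class DBias { No, Yes };
enum class DAct  { No, Yes };

// Stand-in activation derivative (NOT the real dGELU; illustration only).
static float dgelu_approx(float x) { return x > 0.f ? 1.f : 0.1f; }

template <DBias kDBias, DAct kDAct>
void cast_transpose_fused(const float* in, float* out_cast, float* out_t,
                          float* dbias, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      float v = in[r * cols + c];
      if constexpr (kDAct == DAct::Yes) v *= dgelu_approx(v);  // fused dact
      out_cast[r * cols + c] = v;   // "cast" output, row-major
      out_t[c * rows + r]    = v;   // transposed output
      if constexpr (kDBias == DBias::Yes) dbias[c] += v;       // fused dbias reduction
    }
  }
}
```

Because the flags are template parameters, the disabled branches are compiled out entirely (`if constexpr`), so each instantiation pays no runtime cost for the fusions it does not use.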

Type of change: code refactoring

Changes:

The following table gives the runtimes of the previous and the new versions of the fused cast-transpose kernels on an H100 (HBM3). The new version is benchmarked for two values of the n_warps_per_tile parameter (4 and 8):

[Benchmark table attachment: Cast_transpose_fused-Template Code]
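The role of n_warps_per_tile can be sketched as a compile-time tuning knob. This is a hypothetical illustration (the tile height of 64 rows is an assumption, not taken from the PR): with a fixed tile, fewer warps per tile means more sequential iterations per warp but lower per-tile resource usage, which is the trade-off the two benchmarked values (4 and 8) explore.

```cpp
// Hypothetical tile height; the real kernels pick their own tile shape.
constexpr int kTileRows = 64;

// Each warp covers kTileRows / nWarpsPerTile row-iterations of its tile.
template <int nWarpsPerTile>
constexpr int iterations_per_warp() {
  static_assert(kTileRows % nWarpsPerTile == 0,
                "warps must evenly divide the tile");
  return kTileRows / nWarpsPerTile;
}
```

Making it a template parameter lets both variants be instantiated and benchmarked from the same kernel body, as done in the table above.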

Oleg-Goncharov commented 4 months ago

/te-ci

Oleg-Goncharov commented 4 months ago

/te-ci

phu0ngng commented 4 months ago

Hi, I think templating DACT and DBIAS is a good idea. Great work!

I have a small concern: we now have gated_act_cast_transpose.cu, but act_cast_transpose is still tucked away in cast_transpose.cu. I think it would be better either to keep all related cast-transpose functions in one place, OR to split both the gated_act and act related functions into two additional files.

Oleg-Goncharov commented 4 months ago

Agreed. To keep it consistent with the other cast-transpose files, I reverted the split; there is a single file, cast_transpose_fusion.cu, as before.

Oleg-Goncharov commented 4 months ago

/te-ci