🚀 The feature, motivation and pitch
In layer norm backward:

- For input DataType::Half, the persistent buffers are projected back to the three inputs (dy, x, weight); the total size is 3 × sizeof(half) × dim1.
- For input DataType::Float, the persistent buffers are NOT projected; they are xhat and d_xhat, with a total size of 2 × sizeof(float) × dim1.

If I enforce projection for input DataType::Float, there is a significant speedup: for the 2048 x 10240 case the time drops from 274 us to 207 us, and for the 2048 x 1024 case from 39 us to 36 us. The reason is that weight is shared across different rows: if we keep it persistent, we don't need to reload it in the iteration over rows. The projected version needs more registers per thread, but it doesn't reduce occupancy, since all the blocks must be active at the same time for this grid-persistent kernel anyway.
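As a sanity check on the buffer-size arithmetic above, a small sketch (the helper name is illustrative, not from the nvFuser code):

```python
def persistent_bytes(elem_size: int, num_buffers: int, dim1: int) -> int:
    """Total persistent-buffer bytes per row: num_buffers * elem_size * dim1."""
    return num_buffers * elem_size * dim1

dim1 = 10240  # inner dimension from the 2048 x 10240 case

# Half, projected to (dy, x, weight): 3 * sizeof(half) * dim1
half_projected = persistent_bytes(2, 3, dim1)      # 61440 bytes
# Float, unprojected (xhat, d_xhat): 2 * sizeof(float) * dim1
float_unprojected = persistent_bytes(4, 2, dim1)   # 81920 bytes
# Float, projected to (dy, x, weight): 3 * sizeof(float) * dim1
float_projected = persistent_bytes(4, 3, dim1)     # 122880 bytes
```

Note that the projected Float variant is actually larger than the unprojected one; per the measurements above, the speedup comes from keeping weight resident instead of reloading it on every row iteration, not from a smaller footprint.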
Alternatives
No response
Additional context
No response