csarofeen / pytorch

cache var used by each iteration in grid persistent kernel, e.g. weight in layer norm backward #2525

Open liqiangxl opened 1 year ago

liqiangxl commented 1 year ago

🚀 The feature, motivation and pitch

In layer norm backward:

For input DataType::Half, the persistent buffers are projected to the three inputs (dy, x, weight); the total size is 3 * sizeof(half) * dim1.

For input DataType::Float, the persistent buffers are NOT projected; they are xhat and d_xhat, with a total size of 2 * sizeof(float) * dim1.

If I enforce projection for input DataType::Float, there is a significant speedup. For example, for the case 2048 x 10240 the time is reduced from 274 us to 207 us, and for the case 2048 x 1024 the time is reduced from 39 us to 36 us.

The reason is that weight is shared across different rows: if we keep it persistent, we don't need to reload it in the iteration over different rows. The projected version needs more registers per thread, but it doesn't reduce the occupancy ratio, because all the blocks must be active at the same time in this grid persistent kernel.
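A minimal sketch of the idea, written as a hand-coded CUDA kernel rather than nvFuser's generated code; the names (dy, x, weight, mean, rstd, ITEMS_PER_THREAD) and the simplified dx formula are illustrative assumptions. It shows each thread loading its slice of weight into registers once and then reusing it across the grid persistent loop over rows, which is why the extra registers don't change how many rows run concurrently:

```cuda
// Hypothetical sketch, not nvFuser's actual generated kernel.
#include <cuda_fp16.h>

constexpr int ITEMS_PER_THREAD = 4;

__global__ void layer_norm_backward_persistent(
    const __half* __restrict__ dy,     // [dim0, dim1]
    const __half* __restrict__ x,      // [dim0, dim1]
    const __half* __restrict__ weight, // [dim1]
    const float*  __restrict__ mean,   // [dim0]
    const float*  __restrict__ rstd,   // [dim0]
    float* __restrict__ dx,            // [dim0, dim1], fp32 output for simplicity
    int dim0, int dim1) {
  // Persistent register cache of weight: loaded once per thread, reused for
  // every row this block processes.
  float w_reg[ITEMS_PER_THREAD];
  for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
    int col = threadIdx.x + i * blockDim.x;
    w_reg[i] = (col < dim1) ? __half2float(weight[col]) : 0.f;
  }

  // Grid persistent loop: all launched blocks stay resident and stride over
  // rows, so the register cost of w_reg does not reduce how many rows are
  // processed concurrently.
  for (int row = blockIdx.x; row < dim0; row += gridDim.x) {
    for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
      int col = threadIdx.x + i * blockDim.x;
      if (col >= dim1) continue;
      float xhat = (__half2float(x[row * dim1 + col]) - mean[row]) * rstd[row];
      float dyw  = __half2float(dy[row * dim1 + col]) * w_reg[i];
      // Simplified: the full backward also subtracts two per-row reduction
      // terms (sums of dy*w and dy*w*xhat), omitted here for brevity.
      dx[row * dim1 + col] = dyw * rstd[row];
      (void)xhat;
    }
  }
}
```

Without projection, the persistent buffers hold the row-dependent tensors xhat and d_xhat, so nothing in them can be reused across rows; projecting to the inputs lets the row-independent weight stay resident, at the cost of the extra half-precision dy and x buffers.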

Alternatives

No response

Additional context

No response