ankan-ban / llama2.cu

Inference Llama 2 in one file of pure CUDA
MIT License

speed-up softmax #4

Closed by kroggen 1 year ago

kroggen commented 1 year ago

This change pre-computes the pointer arr + h * size into arr_base.

But instead of using a new variable arr_base:

half* __restrict__ arr_base = arr + h * size;

It could just reuse the arr variable itself:

arr = arr + h * size;

or

arr += h * size;
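
For context, here is a minimal sketch of what the kernel prologue might look like with that variant (the signature here is an assumption for illustration, not copied from the repo):

#include <cuda_fp16.h>

__global__ void softmax_kernel(half* __restrict__ arr, int num_heads, int size)
{
    int h = blockIdx.x;   // one thread block per head
    arr += h * size;      // offset the pointer once; no separate arr_base needed
    // ... max/sum reduction and normalization over arr[0..size-1] follow ...
}
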
kroggen commented 1 year ago

The num_heads parameter is currently not used in the softmax() function.

It can be removed
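
For illustration only, the signature and launch could then look something like this (a sketch, not the repo's exact code):

__global__ void softmax_kernel(half* __restrict__ arr, int size);

// one block per head; num_heads is only needed to set the grid size:
// softmax_kernel <<< num_heads, 1024 >>> (att, seq_len);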

kroggen commented 1 year ago

I confess I do not understand how this softmax implementation works, because the reduction operation happens in a single block, but this kernel is called with many blocks (num_heads):

softmax_kernel <<< num_heads, 1024 >>> (att, num_heads, seq_len);
ankan-ban commented 1 year ago

Did you notice any speed improvement with this change? I would have expected the compiler to hoist that computation out of the loop anyway. In any case, it makes the code slightly more readable, so I will take this change. Thanks.

> I confess I do not understand how this softmax implementation works, because the reduction operation happens in a single block, but this kernel is called with many blocks (num_heads):
>
> softmax_kernel <<< num_heads, 1024 >>> (att, num_heads, seq_len);

If you look at the CPU code, you will notice that the softmax is applied independently to each head, so it can be done in parallel. Here we are doing the softmax for each head in its own thread block.
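
To make that concrete, here is a rough, self-contained sketch of a per-head softmax where gridDim.x equals num_heads and each block reduces over its own head. The names and the shared-memory reduction scheme are assumptions for illustration, not the exact code in llama2.cu:

#include <cfloat>
#include <cuda_fp16.h>

// One thread block per head; assumes blockDim.x is a power of two (e.g. 1024).
__global__ void softmax_per_head(half* __restrict__ att, int size)
{
    half* row = att + blockIdx.x * size;   // this block's head
    int tid = threadIdx.x;

    __shared__ float shared[1024];

    // 1) block-wide max (for numerical stability)
    float m = -FLT_MAX;
    for (int i = tid; i < size; i += blockDim.x)
        m = fmaxf(m, __half2float(row[i]));
    shared[tid] = m;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] = fmaxf(shared[tid], shared[tid + s]);
        __syncthreads();
    }
    float max_val = shared[0];
    __syncthreads();

    // 2) exponentiate in place and compute the block-wide sum
    float sum = 0.0f;
    for (int i = tid; i < size; i += blockDim.x) {
        float e = expf(__half2float(row[i]) - max_val);
        row[i] = __float2half(e);
        sum += e;
    }
    shared[tid] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }
    float total = shared[0];
    __syncthreads();

    // 3) normalize
    for (int i = tid; i < size; i += blockDim.x)
        row[i] = __float2half(__half2float(row[i]) / total);
}

// launch: softmax_per_head <<< num_heads, 1024 >>> (att, seq_len);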

kroggen commented 1 year ago

> Did you notice any speed improvement with this change? I would have expected the compiler to hoist that computation out of the loop anyway. In any case, it makes the code slightly more readable, so I will take this change. Thanks.

I don't remember the benchmark results for this case.

I did not know that nvcc could make such an optimization.

> If you look at the CPU code, you will notice that the softmax is applied independently to each head, so it can be done in parallel. Here we are doing the softmax for each head in its own thread block.

Cool!