idiap / fast-transformers

PyTorch library for fast transformer implementations

Local Product CUDA Kernel #51

Closed AndriyMulyar closed 3 years ago

AndriyMulyar commented 4 years ago

Nice library. I have a question regarding the local product (Longformer-style sliding window) kernel you have implemented. If I am correctly interpreting the implementation here, the QK^T operation is decomposed into blocks of size 64 along the num_queries dimension, which are then multiplied via the cuBLAS GEMM implementation against a window of 64 ± local_context/2 keys. The local_context window for each query is then copied out with a custom copy kernel.
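For concreteness, here is a (slow) dense PyTorch sketch of what I understand the local product to compute; the exact window alignment is my guess, not necessarily the kernel's:

```python
import torch

def naive_local_product(Q, K, local_context):
    # Q, K: (N, L, E). For each query i, take dot products with the keys in a
    # local_context-sized window centred on i (assumed alignment, see above).
    N, L, E = Q.shape
    half = local_context // 2
    out = torch.full((N, L, local_context), float("-inf"),
                     device=Q.device, dtype=Q.dtype)
    for i in range(L):
        start = max(0, i - half)
        end = min(L, i - half + local_context)
        out[:, i, : end - start] = torch.einsum(
            "ne,nke->nk", Q[:, i], K[:, start:end]
        )
    return out
```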

With this implementation, the dot products for a much larger context window than local_context are computed but subsequently ignored. Since these computations already happen, is it true that setting local_context to any value in [2, 64] would essentially not alter the latency of the implementation but likely improve the generalization ability of the end transformer (due to the larger context window of each layer)?
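To sanity-check the latency part, one could time just the wide batched GEMM that (on my reading) dominates the kernel; a rough sketch, with shapes taken from my reading above rather than from the actual implementation:

```python
import time
import torch

def time_blocked_gemm(local_context, L=4096, E=64, iters=100):
    # Batched GEMM of one (64 x E) query block against an
    # (E x (local_context + 64)) key slab per block -- assumed shapes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    blocks = L // 64
    Qb = torch.randn(blocks, 64, E, device=device)
    Kb = torch.randn(blocks, E, local_context + 64, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        torch.bmm(Qb, Kb)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

for c in (2, 16, 64):
    print(c, time_blocked_gemm(c))
```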

Thanks!

angeloskath commented 4 years ago

Hi Andriy,

Sorry for the belated reply. What you wrote is probably correct, but I will first answer your questions and then clarify the code a bit to make sure we are on the same page.

  1. "With this implementation, the dot products for a much larger context window than local_context are computed": Not much larger. Instead of performing 64×local_context dot products per block we perform 64×(local_context+64), so all in all we perform precisely 4,096 extra dot products per block of 64 queries (see the short count below this list). However, that does not mean it costs more than not computing them: GPUs are SIMD machines, so chances are we would waste some cycles anyway (due to if checks, etc.).
  2. "Since these computations already happen, is it true that setting local_context to any value in [2, 64] would essentially not alter the latency": Probably, but I cannot be sure. It depends on both the cuBLAS implementation and our copy kernel. However, very small local contexts would definitely be sub-optimal.
  3. "but likely improve the generalization ability of the end transformer": Most likely; the context window is a hyper-parameter to be tuned anyway.
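To make the counting in point 1 concrete, per block of 64 queries (nothing here is specific to the kernel, it is just the arithmetic above):

```python
local_context = 32                     # any window size in [2, 64]
window_only = 64 * local_context       # the dot products we actually need
computed = 64 * (local_context + 64)   # what the single wide GEMM computes
print(computed - window_only)          # 4096 extra, whatever local_context is
```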

Implementation description

The local product is implemented as follows:
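Roughly, in plain (non-CUDA) PyTorch, and under the assumptions already discussed in this thread (block size 64, a window centred on each query; the real kernel also deals with masking and per-sequence lengths):

```python
import torch

def blocked_local_product(Q, K, local_context, block=64):
    # Q, K: (N, L, E). Process queries in blocks of `block`; for every block,
    # one wide GEMM against all keys the block can touch (cuBLAS's job on the
    # GPU), then copy out each query's own window (the custom copy kernel).
    N, L, E = Q.shape
    half = local_context // 2
    out = torch.full((N, L, local_context), float("-inf"),
                     device=Q.device, dtype=Q.dtype)
    for b in range(0, L, block):
        start = max(0, b - half)          # first key any query in the block needs
        end = min(L, b + block + half)    # one past the last key it needs
        scores = torch.einsum(
            "nqe,nke->nqk", Q[:, b:b + block], K[:, start:end]
        )
        for i in range(b, min(b + block, L)):
            s = max(0, i - half)
            e = min(L, i - half + local_context)
            out[:, i, : e - s] = scores[:, i - b, s - start:e - start]
    return out
```

On equal inputs this should agree with the dense reference sketch in the question above.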

As you can see, we have at most an extra 64 dot products per query, so at most 4,096 extra dot products per block of 64 queries. Also, the block size 64 is a parameter, so if needed we could add a dynamic dispatch for smaller context windows in the future (although I don't think it would help, because the GPU would not have enough work to run efficiently).

Let me know if this helps.

Cheers, Angelos