Hi Andriy,
Sorry for the belated reply. What you wrote is probably correct, but let me first answer your questions and then clarify the code a bit to make sure that we are on the same page.
The local product is implemented as follows:
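In outline, it follows the blocked scheme sketched below (a minimal PyTorch sketch for illustration only; the function name and tensor shapes are assumptions, and the actual implementation runs on the GPU with cuBLAS GEMMs for the block products and a custom copy kernel for the extraction):

```python
import torch

def local_dot_product_sketch(Q, K, local_context, block=64):
    """Scores of shape (N, L, local_context) where
    out[n, i, j] = Q[n, i] . K[n, i - local_context//2 + j];
    positions falling outside the sequence stay at -inf."""
    N, L, E = Q.shape
    half = local_context // 2
    out = Q.new_full((N, L, local_context), float("-inf"))
    for start in range(0, L, block):
        end = min(start + block, L)
        # Keys needed by any query in this block: the block itself,
        # widened by half a context window on each side.
        k_lo = max(start - half, 0)
        k_hi = min(end + half, L)
        # One dense product for the whole block (the cuBLAS GEMM in the
        # real kernel); most of it is wider than any single query's
        # window and is discarded below.
        scores = torch.einsum("nqe,nke->nqk", Q[:, start:end], K[:, k_lo:k_hi])
        # Copy out each query's local_context-sized slice (the custom
        # copy kernel in the real implementation).
        for i in range(start, end):
            j_lo = max(0, half - i)                  # clip at sequence start
            j_hi = min(local_context, L - i + half)  # clip at sequence end
            s = i - half + j_lo - k_lo               # offset into the GEMM result
            out[:, i, j_lo:j_hi] = scores[:, i - start, s:s + (j_hi - j_lo)]
    return out

# Example: 2 sequences of 512 queries/keys, feature size 64, window of 32.
Q, K = torch.randn(2, 512, 64), torch.randn(2, 512, 64)
A = local_dot_product_sketch(Q, K, local_context=32)  # -> (2, 512, 32)
```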
As you can see, we compute at most an extra 64 dot products per query, i.e. at most 64 × 64 = 4,096 extra dot products per block of 64 queries. Also, the number 64 is a parameter, so if needed we could possibly add a dynamic dispatch for smaller context windows in the future (although I don't think it would help, because the GPU would not have enough work to run efficiently).
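For concreteness, a quick back-of-the-envelope check of those numbers (a tally assuming the 64-query blocks described above and ignoring edge effects):

```python
block = 64             # query block size used for the GEMMs
local_context = 32     # any (even) value in [2, 64]

window = block + local_context             # keys touched by one block's GEMM
extra_per_query = window - local_context   # = block = 64 discarded products
extra_per_block = block * extra_per_query  # = 4,096, the figure quoted above
print(extra_per_query, extra_per_block)    # -> 64 4096
```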
Let me know if I have helped.
Cheers, Angelos
Nice library. I have a question regarding the local product (Longformer sliding window) kernel you have implemented. If I am correctly interpreting the implementation here, the QK^T operation is decomposed into blocks of size 64 along the num_queries dimension, and each block is dotted, via the cuBLAS GEMM implementation, against a window of keys consisting of the block's 64 positions extended by context_window/2 on either side. The local_context window for each query is then copied out with a custom copy kernel.
With this implementation, dot products for a much larger context window than local_context are computed but subsequently ignored. Since these computations already happen, is it true that setting local_context to any value in [2, 64] would leave the latency essentially unchanged, while likely improving the generalization ability of the resulting transformer (due to each layer's larger context window)?
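To make the premise concrete, here is a rough tally, under the interpretation above, of how many dot products the GEMM computes versus how many the copy kernel keeps per 64-query block (edge effects ignored; the numbers are my own, not taken from the code):

```python
block = 64
for local_context in (2, 16, 32, 64):
    computed = block * (block + local_context)  # products the block GEMM performs
    kept = block * local_context                # products the copy kernel keeps
    print(f"local_context={local_context:2d}: "
          f"computed={computed:5d}, kept={kept:4d} ({kept / computed:.0%} used)")
```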
Thanks!