norabelrose closed this pull request 2 years ago
First of all, sorry for the late reply!
This looks great, and it is something we should have done anyway for all our kernels. (Deadlines are no excuse :-) )
Meanwhile, since your commit, Julien Demouth (@jdemouth-nvidia) has made causal-dot-product about 10x faster, so I merged that first. I can try to port what you did to his kernel as well and then merge everything to master. If, on the other hand, you want to do it yourself, you are more than welcome.
Regardless of whether you still want to pursue it yourself, it looks great and I will definitely be merging it this week or next.
Thanks! Angelos
@angeloskath Thanks for your response! Glad to hear about the new optimizations.
I'll have a look at it tonight or tomorrow and see if I can resolve the conflicts. I reinstalled my CUDA toolkit and got some DeepSpeed CUDA kernels working, so I think I should be able to test this one on my own machine now. But if you beat me to it, that's totally fine; it should be pretty straightforward, just turning functions into templates and making sure to use the hack for `__shared__` memory.
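By "turning functions into templates" I mean changes of roughly this shape (an illustrative signature with approximate parameter names, not the actual kernel):

```cuda
// Before: the kernel is hard-coded to float, e.g.
//   __global__ void causal_dot_product_kernel(const float *queries, const float *keys,
//                                             const float *values, float *product, ...);
//
// After: the same kernel, generic over the element type.
template <typename scalar_t>
__global__ void causal_dot_product_kernel(const scalar_t *queries,
                                          const scalar_t *keys,
                                          const scalar_t *values,
                                          scalar_t *product,
                                          int N, int L, int E, int M) {
    // body unchanged apart from float -> scalar_t
}
```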
@angeloskath Update: I've made all the initial, obviously necessary changes to the newly optimized kernel on my norabelrose/fast-transformers:fast branch. I realized that it'll be a bit more involved to port this version than the previous one, since it uses the special `float4` type for vectorization. There doesn't seem to be a way of getting around writing some explicitly specialized code for half precision there. If I understand it correctly, the best way to do it would be to still use the `float4` type to load/store from shared memory, but then `reinterpret_cast` each element of the `float4` to a `half2` to do the actual math, so you'd be storing 8 scalar values in total for each `float4`.
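Roughly what I have in mind, as an untested sketch (the real kernel's loads and accumulators would look different):

```cuda
#include <cuda_fp16.h>

// When the data is fp16, a float4 pulled from shared memory really packs 8 half
// values. Reinterpret its four 32-bit lanes as half2 pairs and do the math with
// the half2 intrinsics.
__device__ __forceinline__ half2 fma_packed_fp16(float4 q, float4 k, half2 acc) {
    const half2 *q_h2 = reinterpret_cast<const half2 *>(&q);
    const half2 *k_h2 = reinterpret_cast<const half2 *>(&k);
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        acc = __hfma2(q_h2[i], k_h2[i], acc);  // multiply-accumulate two halves at a time
    }
    return acc;
}
```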
I might be able to try coding that up later this week, but since you're more experienced with CUDA programming, it would probably be better for you to do it if you have time.
Also, as a side note, the compiler emits an ungodly number of "expression has no effect" warnings, both on lines 361 and 640, where `__shfl_xor_sync<scalar_t>(...)` is used, and on line 701, where the actual kernel is launched. I'm sort of at a loss as to why that's happening: lines 361 and 640 mutate a local variable, and line 701 passes the `Lhma_params` struct by value, but the kernel then mutates buffers through pointers inside that struct, so the kernel launch definitely "has an effect." I didn't change those lines beyond making them generic over `scalar_t`. But it's probably some silly mistake on my part that I'm just not seeing.
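For context, the kind of construct involved is a warp shuffle used inside a reduction; a generic wrapper over the intrinsic would look something like this (an illustrative sketch, not the code in the repo):

```cuda
// Thin generic wrapper so templated code can call shfl_xor<scalar_t>(...).
// (Sketch only; a c10::Half value would need a cast to __half before hitting the intrinsic.)
template <typename T>
__device__ __forceinline__ T shfl_xor(T value, int lane_mask, unsigned mask = 0xffffffffu) {
    return __shfl_xor_sync(mask, value, lane_mask);
}

// Typical use: a butterfly reduction over the 32 lanes of a warp, mutating a local value.
__device__ float warp_sum(float value) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        value += shfl_xor<float>(value, offset);
    }
    return value;
}
```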
I'm trying to implement a version of this in OpenAI's new Triton DSL. It should allow a simpler code base to run with FP16 or FP32. Still WIP, and happy to take feedback: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py
Why was this closed? Any news on getting fp16 support?
This PR adds support for half- and double-precision floats to the causal dot product CUDA kernel. I tried to do it in the simplest way possible, minimizing the number of changes to the original code. I changed three things in `fast_transformers/causal_product/causal_product_cuda.cu`:

1) The forward and backward kernels, along with the `get_result` function, are now templates, generic over a `scalar_t` parameter.

2) I replaced usages of the `atomicAdd` function with `gpuAtomicAdd`, since it appears that there's no overload of `atomicAdd` for the `c10::Half` type.

3) I replaced direct declarations of dynamically sized `__shared__` memory blocks with calls to a new inline convenience function, `dynamic_generic_shared_memory<T>`, which implements the following simple hack in order to get around the fact that NVCC won't let you declare the same dynamically sized `__shared__` memory block within different template instantiations:

```cuda
extern __shared__ __align__(sizeof(T)) unsigned char _shared_mem[];
return reinterpret_cast<T *>(_shared_mem);
```
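Put together, the helper and a toy call site look roughly like this; the kernel below is hypothetical and only meant to show the pattern (it also shows the `gpuAtomicAdd` swap from point 2):

```cuda
#include <ATen/cuda/Atomic.cuh>  // provides gpuAtomicAdd (older PyTorch versions: THC/THCAtomics.cuh)

template <typename T>
__device__ __forceinline__ T *dynamic_generic_shared_memory() {
    // Declare the dynamic __shared__ block once as raw bytes and cast, so different
    // template instantiations don't redeclare it with conflicting element types.
    extern __shared__ __align__(sizeof(T)) unsigned char _shared_mem[];
    return reinterpret_cast<T *>(_shared_mem);
}

// Hypothetical kernel, not the real one: stage n values in shared memory, then
// atomically accumulate them into *out. gpuAtomicAdd has a c10::Half specialization,
// which the raw atomicAdd intrinsic does not.
template <typename scalar_t>
__global__ void accumulate_kernel(const scalar_t *__restrict__ in, scalar_t *out, int n) {
    scalar_t *buffer = dynamic_generic_shared_memory<scalar_t>();
    int i = threadIdx.x;
    if (i < n) buffer[i] = in[i];
    __syncthreads();
    if (i < n) gpuAtomicAdd(out, buffer[i]);
}
```

The dynamic shared-memory size is still passed as the third kernel launch parameter, exactly as before; only the in-kernel declaration changes.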
See here for discussion.

I use the `AT_DISPATCH_FLOATING_TYPES_AND_HALF()` macro to pick the correct kernel template instantiation based on the type of the query tensor, which seems to be the recommended way to do it according to the PyTorch docs (see the rough sketch at the end of this description).

I haven't been able to really test this at all; all I know is that it compiles. I'm having weird issues with my NVCC/CUDA toolkit install in general, where even the master branch of fast-transformers either won't compile, or it compiles but fails with a dynamic linking error when you try to load it from Python. So I'm working on fixing that so I can test/use this myself. But hopefully any changes that need to be made to this PR will be pretty simple, since, as I said, I tried to minimize the number of changes relative to master. This is the first time I've ever done any CUDA kernel programming, but I'd really like to use `CausalDotProduct` with AMP, so I figured I'd try implementing it myself.
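For reference, here's roughly what the dispatch looks like at the launch site, using a trivial made-up kernel in place of the real one:

```cuda
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Trivial stand-in kernel; the real causal dot product kernels take many more arguments.
template <typename scalar_t>
__global__ void copy_kernel(const scalar_t *__restrict__ in, scalar_t *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

void dispatch_example(const at::Tensor &queries, at::Tensor &out) {
    const int n = static_cast<int>(queries.numel());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        queries.scalar_type(), "causal_dot_product", [&] {
            // The macro binds scalar_t to float, double, or c10::Half, so this
            // instantiates and launches the matching kernel template.
            copy_kernel<scalar_t><<<(n + 255) / 256, 256>>>(
                queries.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
        });
}
```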