idiap / fast-transformers

Pytorch library for fast transformer implementations

Add support for half- and double-precision floats to CausalDotProduct (CUDA) #74

Closed: norabelrose closed this pull request 2 years ago

norabelrose commented 3 years ago

This PR adds support for half- and double-precision floats to the causal dot product CUDA kernel. I tried to do it in the simplest way possible, minimizing the number of changes to the original code. I changed three things in fast_transformers/causal_product/causal_product_cuda.cu:

1) The forward and backward kernels, along with the get_result function, are now templates, generic over a scalar_t parameter.
2) I replaced usages of the atomicAdd function with gpuAtomicAdd, since there appears to be no overload of atomicAdd for the c10::Half type.
3) I replaced direct declarations of dynamically sized __shared__ memory blocks with calls to a new inline convenience function, dynamic_generic_shared_memory<T>, which implements a simple hack (a sketch of the helper follows below) to get around the fact that NVCC won't let you declare the same dynamically sized __shared__ memory block within different template instantiations: extern __shared__ __align__(sizeof(T)) unsigned char _shared_mem[]; return reinterpret_cast<T *>(_shared_mem); See here for discussion.
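For clarity, the helper is essentially just the following (a minimal sketch; the version on my branch may differ in minor details):

```cuda
// Minimal sketch of dynamic_generic_shared_memory<T>. NVCC rejects a plain
// templated `extern __shared__ T buf[];` because each instantiation would
// redeclare the same symbol with a different type, so the block is declared
// once as raw bytes and reinterpreted as the requested element type.
template <typename T>
__device__ __forceinline__ T *dynamic_generic_shared_memory() {
    extern __shared__ __align__(sizeof(T)) unsigned char _shared_mem[];
    return reinterpret_cast<T *>(_shared_mem);
}
```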

I use the AT_DISPATCH_FLOATING_TYPES_AND_HALF() macro to pick the correct kernel template instantiation based on the type of the query tensor, which seems to be the recommended way to do it according to the PyTorch docs.
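For illustration, the dispatch pattern looks roughly like this (a toy sketch only; scale_kernel and scale_inplace are made-up names, not code from this PR):

```cuda
#include <torch/extension.h>

// Illustrative only: a trivial templated kernel dispatched on the tensor's
// scalar type, including c10::Half, via AT_DISPATCH_FLOATING_TYPES_AND_HALF.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t *data, scalar_t factor, int64_t n) {
    int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i < n) data[i] = data[i] * factor;
}

void scale_inplace(torch::Tensor x, double factor) {
    const int threads = 256;
    const int blocks = (int)((x.numel() + threads - 1) / threads);
    // The macro defines scalar_t inside the lambda for the matching dtype.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "scale_inplace", [&] {
        scale_kernel<scalar_t><<<blocks, threads>>>(
            x.data_ptr<scalar_t>(),
            static_cast<scalar_t>(factor),
            x.numel());
    });
}
```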

I haven't really been able to test this at all; all I know is that it compiles. I'm having odd issues with my NVCC/CUDA toolkit install in general, where even the master branch of fast-transformers either won't compile, or compiles but fails with a dynamic linking error when you try to load it from Python. I'm working on fixing that so I can test and use this myself. But hopefully any changes that need to be made to this PR will be fairly simple since, as I said, I tried to minimize the number of changes relative to master. This is the first time I've done any CUDA kernel programming, but I'd really like to use CausalDotProduct with AMP, so I figured I'd try implementing it myself.

angeloskath commented 3 years ago

First of all, sorry for the late reply!

This looks great, and it is something we should have done anyway for all our kernels. (Deadlines are no excuse :-) ).

In the meantime, since your commit, Julien Demouth (@jdemouth-nvidia) has made the causal dot product about 10x faster, so I merged that first. I can try to port what you did to his kernel as well and then merge everything into master. If, on the other hand, you want to do it yourself, you are more than welcome.

Regardless of whether you want to pursue that yourself, it looks great and I will definitely merge it this week or next.

Thanks! Angelos

norabelrose commented 3 years ago

@angeloskath Thanks for your response, and glad to hear about the new optimizations!

I'll have a look at it tonight or tomorrow and see if I can resolve the conflicts. I reinstalled my CUDA toolkit and got some DeepSpeed CUDA kernels working, so I think I should be able to test this one on my own machine now. But if you beat me to it, that's totally fine; it should be pretty straightforward: just turning functions into templates and making sure to use the hack for __shared__ memory.

norabelrose commented 3 years ago

@angeloskath Update: I've made all the initial, obviously necessary changes to the newly optimized kernel on my norabelrose/fast-transformers:fast branch. I realized that porting this version will be a bit more involved than the previous one, since it uses the special float4 type for vectorization. There doesn't seem to be a way around writing some explicitly specialized half-precision code there. If I understand it correctly, the best approach would be to keep using the float4 type to load from and store to shared memory, but reinterpret_cast each element of the float4 to a half2 for the actual math, so each float4 would hold 8 scalar values in total.
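Roughly what I have in mind (just a sketch of the idea, untested; the helper name accumulate_half8 is made up for illustration):

```cuda
#include <cuda_fp16.h>

// Sketch: a float4 load moves 16 bytes, i.e. eight fp16 values. Each 32-bit
// lane of the float4 can be reinterpreted as a __half2 so the math runs on
// packed half pairs; here we just accumulate into four __half2 accumulators.
__device__ void accumulate_half8(const float4 &packed, __half2 acc[4]) {
    const __half2 *vals = reinterpret_cast<const __half2 *>(&packed);
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        acc[i] = __hadd2(acc[i], vals[i]);  // packed fp16 add, sm_53 and up
    }
}
```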

I might be able to try coding that up later this week but since you're more experienced with CUDA programming, it would probably be better for you to do it if you have time.

Also, as a side note, the compiler emits an ungodly number of "warning: expression has no effect" messages, both on lines 361 and 640, where __shfl_xor_sync<scalar_t>(...) is used, and on line 701, where the actual kernel is launched. I'm somewhat at a loss as to why that's happening: lines 361 and 640 mutate a local variable, and line 701 passes the Lhma_params struct by value, but the kernel then mutates buffers through pointers inside that struct, so the kernel launch definitely "has an effect." I didn't change any of those lines beyond making them generic over scalar_t, so it's probably some silly mistake on my part that I'm just not seeing.
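For context, the pattern at lines like 361 and 640 is the usual warp-level butterfly reduction, where the shuffled value is folded back into a local variable on every iteration. A generic sketch (not the file's actual code):

```cuda
// Generic warp-level butterfly sum using __shfl_xor_sync; `val` is a local
// variable that is mutated on every iteration, so the expression clearly
// does have an effect.
template <typename T>
__device__ T warp_reduce_sum(T val) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val;
}
```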

calclavia commented 3 years ago

I'm trying to implement a version of this in OpenAI's new Triton DSL. It should allow a simpler code base to run with FP16 or FP32. Still WIP, and happy to take feedback: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py

maxall41 commented 8 months ago

Why was this closed? Any news on getting fp16 support?