csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

warp reduction for complex numbers #2526

Closed liqiangxl closed 1 year ago

liqiangxl commented 1 year ago

🚀 The feature, motivation and pitch

Warp reduction is disabled for complex numbers because __shfl_xor_sync() does not support complex types. According to the CUDA C++ Programming Guide, the supported data types are:

"T can be int, unsigned int, long, unsigned long, long long, unsigned long long, float or double. With the cuda_fp16.h header included, T can also be __half or __half2. Similarly, with the cuda_bf16.h header included, T can also be __nv_bfloat16 or __nv_bfloat162."

One way to add support would be to shuffle the real and imaginary parts of the complex value separately. We would need to check the performance impact, since this issues two shuffles per exchange.
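A minimal sketch of that idea, not the project's actual implementation. It assumes cuFloatComplex from cuComplex.h as the complex type (the code generator may use a different complex class), and the names shfl_xor_complex, warp_sum_complex, warp_sum_kernel and run_warp_sum are hypothetical, introduced only for illustration. Each exchange is done with two float shuffles, one per component, inside a butterfly sum over a full 32-lane warp:

```cuda
#include <cuda_runtime.h>
#include <cuComplex.h>

// Hypothetical helper: XOR-shuffle a complex value by exchanging its real
// and imaginary parts with two separate float shuffles.
__device__ inline cuFloatComplex shfl_xor_complex(
    unsigned mask, cuFloatComplex v, int lane_mask) {
  float re = __shfl_xor_sync(mask, cuCrealf(v), lane_mask);
  float im = __shfl_xor_sync(mask, cuCimagf(v), lane_mask);
  return make_cuFloatComplex(re, im);
}

// Butterfly sum across the 32 lanes of a warp.
// After the loop every lane holds the full sum.
__device__ inline cuFloatComplex warp_sum_complex(cuFloatComplex v) {
  const unsigned full_mask = 0xffffffffu;  // assumes all 32 lanes are active
  for (int offset = 16; offset > 0; offset >>= 1) {
    v = cuCaddf(v, shfl_xor_complex(full_mask, v, offset));
  }
  return v;
}

// Launch with exactly one warp (32 threads); lane 0 writes the result.
__global__ void warp_sum_kernel(const cuFloatComplex* in, cuFloatComplex* out) {
  cuFloatComplex total = warp_sum_complex(in[threadIdx.x]);
  if (threadIdx.x == 0) {
    *out = total;
  }
}

// Host-side helper: sums 32 complex values on the device using one warp.
// Error checking is omitted to keep the sketch short.
cuFloatComplex run_warp_sum(const cuFloatComplex* host_in) {
  cuFloatComplex* d_in = nullptr;
  cuFloatComplex* d_out = nullptr;
  cuFloatComplex result = make_cuFloatComplex(0.0f, 0.0f);
  cudaMalloc(&d_in, 32 * sizeof(cuFloatComplex));
  cudaMalloc(&d_out, sizeof(cuFloatComplex));
  cudaMemcpy(d_in, host_in, 32 * sizeof(cuFloatComplex),
             cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&result, d_out, sizeof(cuFloatComplex), cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_out);
  return result;
}
```

Calling run_warp_sum on 32 host values returns their complex sum. Whether the second shuffle per step costs anything measurable compared with the real-valued path would still have to be benchmarked, as noted above.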

https://docs.nvidia.com/cuda/cuda-c-programming-guide/#:~:text=T%20can%20be%20int%2C%20unsigned%20int%2C%20long%2C%20unsigned%20long%2C%20long%20long%2C%20unsigned%20long%20long%2C%20float%20or%20double.%20With%20the%20cuda_fp16.h%20header%20included%2C%20T%20can%20also%20be%20__half%20or%20__half2.%20Similarly%2C%20with%20the%20cuda_bf16.h%20header%20included%2C%20T%20can%20also%20be%20__nv_bfloat16%20or%20__nv_bfloat162.

Alternatives

No response

Additional context

No response