rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License

Poor performance with __half and __nv_bfloat16 #416

Open borisfom opened 5 months ago

borisfom commented 5 months ago

In the atomAdd overloads, the native atomicAdd should be used for __half and __nv_bfloat16 instead of AtomicAddDecimalImpl, like this:

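```cuda
#include <cuda.h>       // CUDA_VERSION
#include <cuda_fp16.h>  // __half
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>  // __nv_bfloat16 (CUDA 11+)
#endif

// AtomicAddDecimalImpl is the existing CAS-based fallback defined elsewhere in pytorch_scatter (not shown).
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
static inline __device__ void atomAdd(__half* address, __half val)
{
    AtomicAddDecimalImpl<__half, sizeof(__half)>()(address, val);
}
#else
// Native atomicAdd on __nv_bfloat16 needs compute capability 8.x+ (and CUDA 11+);
// older GPUs still need a separate bf16 fallback overload (not shown).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11000
static inline __device__ void atomAdd(__nv_bfloat16* address, __nv_bfloat16 val)
{
    atomicAdd(address, val);
}
#endif
// Native atomicAdd on __half needs compute capability 7.x+ and CUDA 10+.
static inline __device__ void atomAdd(__half* address, __half val)
{
    atomicAdd(address, val);
}
#endif
```
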
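Why the fallback is slow: a 16-bit atomic add is usually emulated with a compare-and-swap loop on the enclosing 32-bit word, retrying whenever another thread changed either half in the meantime. The sketch below is an independent illustration of that technique, assumed to be similar in spirit to AtomicAddDecimalImpl but not its actual source; cas_half_atomic_add and cas_half_demo are hypothetical names.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Sketch of a CAS-based 16-bit atomic add (NOT pytorch_scatter's actual code).
// Hardware CAS works on 32-bit words, so we update the aligned word that contains
// *address and retry whenever another thread modified that word first.
__device__ void cas_half_atomic_add(__half* address, __half val)
{
    uintptr_t addr = reinterpret_cast<uintptr_t>(address);
    unsigned int* word = reinterpret_cast<unsigned int*>(addr & ~uintptr_t(3));
    const bool high = (addr & 2) != 0;  // which 16-bit half of the word we own

    unsigned int old = *word;
    unsigned int assumed;
    do {
        assumed = old;
        unsigned short bits =
            static_cast<unsigned short>(high ? (assumed >> 16) : (assumed & 0xffffu));
        // Add in float, then round back to half.
        __half sum = __float2half(__half2float(__ushort_as_half(bits)) + __half2float(val));
        unsigned int sum_bits = __half_as_ushort(sum);
        unsigned int updated = high ? ((assumed & 0x0000ffffu) | (sum_bits << 16))
                                    : ((assumed & 0xffff0000u) | sum_bits);
        old = atomicCAS(word, assumed, updated);
    } while (old != assumed);  // lost the race: retry with the freshly observed word
}

// Hypothetical demo kernel: every thread adds 1.0 into the same __half slot.
__global__ void cas_half_demo(__half* out)
{
    cas_half_atomic_add(out, __float2half(1.0f));
}
```

Under contention every failed atomicCAS costs another round trip, whereas the native atomicAdd on __half (sm_70+) is a single atomic instruction.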
rusty1s commented 5 months ago

Do you mind sending a PR to fix?

borisfom commented 5 months ago

Sure will! I have noticed, however, that after installing pytorch_scatter my calls still end up in PyTorch's own scatter_add. Is the scatter code here obsolete? In any case, reading the PyTorch source I learned that the matter is more complicated: native atomicAdd(__nv_bfloat16) is reportedly very slow, so PyTorch uses __nv_bfloat162 for it with a questionable trick, and performance was not great with that either. So I am going to investigate the options further. The most attractive option is to use __nv_bfloat162 throughout, but that would require changes to the algorithm, and I am not sure it is even possible. What do you think?
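The "questionable trick" is presumably the packed-pair approach: atomicAdd on __nv_bfloat162 (sm_80+) updates an aligned 32-bit pair, so a single bf16 element can be updated by adding val to it and 0 to its neighbour. A minimal sketch, assuming a contiguous buffer of numel elements; packed_bf16_atomic_add and packed_bf16_demo are hypothetical names, not PyTorch's API.

```cuda
#include <cuda_bf16.h>
#include <cstdint>

// Hypothetical helper: add val to base[index] via the packed __nv_bfloat162 atomicAdd
// (sm_80+). The packed add also touches the neighbouring element, which receives +0,
// so the neighbour's value is left numerically unchanged.
__device__ void packed_bf16_atomic_add(__nv_bfloat16* base, int64_t index,
                                       int64_t numel, __nv_bfloat16 val)
{
    __nv_bfloat16* target = base + index;
    const __nv_bfloat16 zero = __float2bfloat16(0.0f);
    const bool low = reinterpret_cast<uintptr_t>(target) % sizeof(__nv_bfloat162) == 0;

    if (low && index + 1 < numel) {
        // target is the low (.x) half of an aligned pair: add (val, 0).
        atomicAdd(reinterpret_cast<__nv_bfloat162*>(target), __halves2bfloat162(val, zero));
    } else if (!low && index > 0) {
        // target is the high (.y) half: add (0, val) to the pair starting one element earlier.
        atomicAdd(reinterpret_cast<__nv_bfloat162*>(target - 1), __halves2bfloat162(zero, val));
    } else {
        // The partner element would fall outside the buffer: use the scalar bf16 atomic.
        atomicAdd(target, val);
    }
}

// Hypothetical demo kernel: thread t adds 1.0 to element t % numel.
__global__ void packed_bf16_demo(__nv_bfloat16* buf, int64_t numel)
{
    packed_bf16_atomic_add(buf, threadIdx.x % numel, numel, __float2bfloat16(1.0f));
}
```

The bounds checks matter: at either end of the buffer the partner element may not belong to the tensor, so the sketch falls back to the scalar bf16 atomicAdd there.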

rusty1s commented 5 months ago

Yeah, that is correct. We just use the scatter_add implementation from PyTorch. As such, the scatter_add implementation in torch-scatter is indeed kinda obsolete by now.

borisfom commented 5 months ago

Thanks for the confirmation! I guess there is no point in doing a PR then.

borisfom commented 5 months ago

If you have an idea how to implement the same scatter semantics with the __nv_bfloat162 type, that may be something the PyTorch folks could use! As it stands now, the best thing to do is to convert bf16 to float before the scatter and convert back afterwards; that is way faster than trying to do it in bf16.
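The float round-trip workaround can be sketched at the kernel level as follows, assuming a simple 1-D scatter_add (out[index[i]] += src[i]) on device pointers with every index in [0, m). All kernel and function names are hypothetical, and error checking is omitted.

```cuda
#include <cuda_bf16.h>
#include <cstdint>

// Step 1: widen the existing bf16 output into a float32 scratch buffer.
__global__ void bf16_to_f32(const __nv_bfloat16* in, int64_t m, float* acc)
{
    int64_t j = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (j < m) acc[j] = __bfloat162float(in[j]);
}

// Step 2: scatter-add in float32 using the native float atomicAdd.
__global__ void scatter_add_f32(const __nv_bfloat16* src, const int64_t* index,
                                int64_t n, float* acc)
{
    int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(acc + index[i], __bfloat162float(src[i]));
}

// Step 3: round back to bf16 once, after all additions are done.
__global__ void f32_to_bf16(const float* acc, int64_t m, __nv_bfloat16* out)
{
    int64_t j = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (j < m) out[j] = __float2bfloat16(acc[j]);
}

// Hypothetical host driver: out[index[i]] += src[i] for i in [0, n), computed in float32.
void scatter_add_bf16_via_f32(const __nv_bfloat16* src, const int64_t* index, int64_t n,
                              __nv_bfloat16* out, int64_t m)
{
    if (m <= 0) return;
    const int threads = 256;
    auto blocks = [&](int64_t count) {
        return static_cast<unsigned>((count + threads - 1) / threads);
    };
    float* acc = nullptr;
    cudaMalloc(&acc, m * sizeof(float));
    bf16_to_f32<<<blocks(m), threads>>>(out, m, acc);
    if (n > 0) scatter_add_f32<<<blocks(n), threads>>>(src, index, n, acc);
    f32_to_bf16<<<blocks(m), threads>>>(acc, m, out);
    cudaDeviceSynchronize();  // wait for all kernels before releasing the scratch buffer
    cudaFree(acc);
}
```

Besides using the fast native float atomicAdd, accumulating in float32 rounds to bf16 only once per output element instead of after every addition.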