marius311 opened 3 years ago
(With a little guidance on the strategy, I could probably hazard a PR myself, if that'd be easier.)
Couple of things:

- `dot`: that's `cublasDotEx`, which @kshyatt added in https://github.com/JuliaGPU/CUDA.jl/pull/904, but only for Float16 types. Other combinations of types are supported too, see https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx. Either we express this through methods, or we do it imperatively as with `gemmEx` (which is much more complicated in terms of supported types): https://github.com/JuliaGPU/CUDA.jl/blob/c77e98549a33b770017ea9cc09b950a7f47d3ab7/lib/cublas/wrappers.jl#L807-L907
- `gemm`: this is a complicated routine that checks whether `gemmEx` is supported: https://github.com/JuliaGPU/CUDA.jl/blob/c77e98549a33b770017ea9cc09b950a7f47d3ab7/lib/cublas/linalg.jl#L176-L225

In addition, for `dot` we could add a fallback method (i.e. without element-type constraints) that just does `a'*b`, falling back to the well-optimized GEMM implementation. On the other hand, when GEMM fails to select a fast implementation it'll use the horribly slow GPUArrays.jl implementation, in which case `sum(a.*b)` might be a better fallback (although that uses two kernels and an intermediate allocation, so there's quite some overhead).
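A minimal sketch of the two fallback strategies discussed above, written against plain `AbstractVector`s so it also runs on the CPU. The names `gemm_dot`, `broadcast_dot`, `fallback_dot` and the `GemmTypes` union are hypothetical, and the set of element types assumed to have a fast GEMM path is a simplification of the much more involved check in CUDA.jl linked above. One detail worth noting: for vectors, `a' * b` in LinearAlgebra dispatches straight back to `dot`, so the GEMM route here reshapes both vectors into n×1 matrices to actually reach matrix multiplication.

```julia
using LinearAlgebra

# Element types for which this sketch *assumes* a fast BLAS/CUBLAS GEMM exists
# (hypothetical; CUDA.jl's real check is far more detailed).
const GemmTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

# GEMM route: view x and y as n×1 matrices; the adjoint of the first is a
# conjugated 1×n matrix, so the product is a 1×1 matrix holding x'y.
# Reducing that 1×1 result with `sum` avoids scalar indexing on GPU arrays.
gemm_dot(x::AbstractVector{T}, y::AbstractVector{T}) where {T<:GemmTypes} =
    sum(reshape(x, :, 1)' * reshape(y, :, 1))

# Generic route: one broadcast plus one reduction. Works for any element
# types, but materializes the intermediate array of products.
function broadcast_dot(x::AbstractVector, y::AbstractVector)
    # Guard against silent broadcasting of a length-1 vector.
    length(x) == length(y) || throw(DimensionMismatch("vectors must have equal length"))
    return sum(conj.(x) .* y)
end

# Dispatch: use GEMM when both vectors share a GEMM-friendly element type,
# otherwise fall back to the broadcast version.
fallback_dot(x::AbstractVector{T}, y::AbstractVector{T}) where {T<:GemmTypes} =
    gemm_dot(x, y)
fallback_dot(x::AbstractVector, y::AbstractVector) = broadcast_dot(x, y)
```

Mixed or unsupported element types (e.g. `Float32` with `Float64`, or integers) never match the first method, so they take the broadcast path instead of ending up in a slow generic GEMM.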
These don't seem to currently work (CUDA 3.3):
as they fall back to a generic method that triggers scalar indexing. It would be nice to have these implemented, even with something as simple as `sum(conj.(x) .* y)`, which at least works on GPU.

Maybe worth mentioning: I didn't really care about these until upgrading Zygote 0.6.11 -> 0.6.12. A recent change (bisected down to https://github.com/FluxML/Zygote.jl/pull/973) seems to make Zygote emit such dot products where previously it didn't. Here's an example which triggers scalar indexing after that commit but not before:
Anyway, this isn't really relevant to CUDA.jl, but I figured it might provide some context.
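The `sum(conj.(x) .* y)` one-liner suggested above materializes a temporary array of products before reducing it. A hedged sketch of a single-reduction alternative using Base's multi-argument `mapreduce`; the name `fused_dot` is hypothetical, and whether a given GPU array package lowers this to one kernel without the temporary is an assumption to verify, not something established in this thread.

```julia
using LinearAlgebra

# Hypothetical single-pass variant of `sum(conj.(x) .* y)`: the conjugated
# product and the `+` reduction are expressed as one multi-argument
# `mapreduce`. On the CPU, Base still materializes the mapped values
# internally; a GPU backend *may* fuse them (assumption).
function fused_dot(x::AbstractArray, y::AbstractArray)
    size(x) == size(y) || throw(DimensionMismatch("x and y must have the same size"))
    # A correctly typed zero, so empty inputs return 0 instead of throwing.
    init = zero(eltype(x)) * zero(eltype(y))
    return mapreduce((a, b) -> conj(a) * b, +, x, y; init=init)
end

x = ComplexF64[1 + 2im, 3 - 1im]
y = ComplexF64[2 - 1im, 1 + 1im]
fused_dot(x, y) ≈ dot(x, y)  # true
```

Passing `init` keeps empty inputs well-defined; without it, reducing an empty collection throws an error.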