LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

NCCL Complex wrapper #747

Open nikopj opened 2 weeks ago

nikopj commented 2 weeks ago

The NCCL backend in the distributed utils does not support complex values (see issue). Can we add a convenient wrapper in the NCCL extension to support broadcast, reduce, etc., likely using reim and complex? I'm happy to get started on it, but would like some feedback on the preferred location and so on.

avik-pal commented 2 weeks ago

There are 3 possible solutions:

  1. For the cases with LuxCUDADevice (https://github.com/LuxDL/Lux.jl/blob/68238456942b016f03dc40d47d7f312a7587c7a5/ext/LuxMPINCCLExt.jl#L38-L42), split the buffer with reim and recombine it with complex if the buffer type is ::AbstractArray{<:Complex}.
  2. Alternatively, detect complex buffers the same way as in point 1 and forward the call to MPI, similar to https://github.com/LuxDL/Lux.jl/blob/68238456942b016f03dc40d47d7f312a7587c7a5/ext/LuxMPINCCLExt.jl#L44-L47. If MPI is CUDA-aware, no device-to-host copy is performed, and AFAIK MPI supports complex numbers.
  3. Add the support directly to NCCL.jl, though it might be worth opening an issue there first and asking whether the feature is welcome.
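
A rough sketch of what option 1 could look like, kept independent of NCCL so it runs in a plain Julia session. The names `complex_allreduce!` and `fake_sum!` are hypothetical and not part of Lux or NCCL.jl; `real_allreduce!` stands in for whatever real-valued collective the backend exposes. It uses `real.` and `imag.` to do the same split as reim. The split is only valid for elementwise-linear operations such as a sum reduction or a broadcast, not for a product reduction.

```julia
# Hypothetical helper (not a Lux or NCCL.jl API): run a real-valued, in-place
# collective separately on the real and imaginary parts of a complex buffer.
# Assumes the collective is elementwise-linear (e.g. a sum reduction).
function complex_allreduce!(real_allreduce!, buf::AbstractArray{<:Complex})
    re_part = real.(buf)   # temporary real-valued copy of the real parts
    im_part = imag.(buf)   # temporary real-valued copy of the imaginary parts
    real_allreduce!(re_part)
    real_allreduce!(im_part)
    buf .= complex.(re_part, im_part)   # write the combined result back in place
    return buf
end

# Real buffers need no special handling and go straight to the collective.
complex_allreduce!(real_allreduce!, buf::AbstractArray{<:Real}) = real_allreduce!(buf)

# Stand-in collective for local testing: pretends two ranks hold identical
# data and sums them, i.e. doubles the buffer in place.
fake_sum!(x) = (x .*= 2; x)

z = ComplexF32[1 + 2im, 3 - 4im]
complex_allreduce!(fake_sum!, z)
```

This version allocates two temporary arrays per call. Whether that cost is acceptable compared with forwarding to MPI (option 2) would need to be measured.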