JuliaGPU / GPUArrays.jl

Reusable array functionality for Julia's various GPU backends.

Remove special-casing of Ref in broadcast. #510

Closed: maleadt closed this 6 months ago

maleadt commented 6 months ago

In https://github.com/JuliaGPU/GPUArrays.jl/pull/240 / https://github.com/JuliaGPU/GPUArrays.jl/commit/7807069d1fcc843e4889df5600a2425c95fa8cb2, I added the ability to broadcast with Refs of GPU arrays and still use the GPU broadcast style. This is convenient for accessing elements 'outside' of what is being broadcast over:

julia> lookup_table = jl(rand(10));
julia> data = jl(rand(1:10, 2, 2));

julia> broadcast(data, Ref(lookup_table)) do x, lookup_table
           lookup_table[x]
       end
2×2 JLArray{Float64, 2}:
 0.42458  0.407246
 0.42458  0.482623
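
For context, the special-casing essentially amounts to giving a Ref of a GPU array the GPU broadcast style. A rough, hypothetical sketch of that idea (not necessarily the exact method this PR deletes), written against the JLArray reference backend:

using JLArrays
import Base.Broadcast: BroadcastStyle

# Hypothetical illustration (not GPUArrays' actual internals): let a Ref that
# wraps a JLArray adopt the JLArray broadcast style instead of the default
# 0-dimensional Ref style, so the surrounding broadcast keeps dispatching to
# the GPU-style machinery.
BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:JLArray} = BroadcastStyle(AT)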

(Ab)using Ref like that isn't great, though. For one, it breaks the ability to broadcast with GPU array objects on the CPU, as noted by @ToucheSir:

julia> x = jl([1])
1-element JLArray{Int64, 1}:
 1

julia> xs = [copy(x)]
1-element Vector{JLArray{Int64, 1}}:
 [1]

julia> xs .= Ref(x)
ERROR: This object is not a GPU array

Furthermore, there is nowadays a simpler way to accomplish the above behavior, namely by capturing the additional arrays from the enclosing scope instead of passing them through the broadcast:

julia> function demo()
           lookup_table = jl(rand(10))
           data = jl(rand(1:10, 2, 2))
           broadcast(data) do x
               lookup_table[x]
           end
       end

julia> demo()
2×2 JLArray{Float64, 2}:
 0.42458  0.407246
 0.42458  0.482623

So in this PR, I remove the special Ref behavior. That's a breaking change, so it may require changes downstream. Looking at the original PR, I feel like I may have added this at the request of @ChrisRackauckas, so it may be prudent to check whether the SciML stack uses this pattern, and whether it could be updated to use captures instead.

Fixes https://github.com/JuliaGPU/GPUArrays.jl/issues/505
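
If downstream packages do turn out to use the Ref pattern, the update should be mechanical. A minimal before/after sketch (hypothetical variable names, using the jl reference backend from the examples above):

using JLArrays: jl

lookup_table = jl(rand(10))
data = jl(rand(1:10, 2, 2))

# Before: relied on the Ref special-casing that this PR removes.
# result = broadcast((x, lut) -> lut[x], data, Ref(lookup_table))

# After: capture the auxiliary array in the broadcasted closure instead.
result = broadcast(x -> lookup_table[x], data)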

ChrisRackauckas commented 6 months ago

I've always considered using Ref in broadcast a style error, and have preferred using tuples to do it:

x = [rand(5) for i in 1:5]
y = rand(5)
z = x .+ (y,)

since the compiler infers it a lot better. So if that still works fine, SciML should be good, and I think our style guide "enforces" this.

maleadt commented 6 months ago

That code allocates a vector of vectors, though, so the outer broadcast was never supposed to run directly on the GPU. That behavior is unchanged:

julia> x = [jl(rand(5)) for i in 1:5]
5-element Vector{JLArray{Float64, 1}}:
 [0.12528405590764002, 0.44645909392009553, 0.2696314861947734, 0.5269768328096068, 0.09627383550638124]
 [0.6121100687597992, 0.2663646221991879, 0.39010923009673526, 0.4466685936484782, 0.5568502880466518]
 [0.1605983787982691, 0.8632715126999709, 0.9841132855962932, 0.817389209967217, 0.5730778757136838]
 [0.4205266752754707, 0.12814316624764188, 0.24065136095536133, 0.04061488703288718, 0.02559930838997082]
 [0.2933998629643869, 0.3496545372178772, 0.5778249397950344, 0.9940432389025162, 0.47729002372509555]

julia> y = jl(rand(5))
5-element JLArray{Float64, 1}:
 0.6350116147386349
 0.1403931418040173
 0.8814164152048923
 0.66297803654671
 0.17378635387723118

julia> z = x .+ (y,)
5-element Vector{JLArray{Float64, 1}}:
 [0.7602956706462749, 0.5868522357241128, 1.1510479013996657, 1.189954869356317, 0.2700601893836124]
 [1.2471216834984342, 0.4067577640032052, 1.2715256453016277, 1.1096466301951882, 0.730636641923883]
 [0.795609993536904, 1.0036646545039882, 1.8655297008011855, 1.4803672465139268, 0.746864229590915]
 [1.0555382900141055, 0.2685363080516592, 1.1220677761602538, 0.7035929235795971, 0.199385662267202]
 [0.9284114777030218, 0.4900476790218945, 1.4592413549999268, 1.6570212754492262, 0.6510763776023267]

maleadt commented 6 months ago

Seeing how this doesn't trip up any downstream CI, and SciML is apparently fine with it, I propose to try this out without releasing a breaking version. Behaving differently from Array can also be considered a bug, after all.
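
For reference, this is how plain Arrays treat a Ref in broadcast: as a 0-dimensional wrapper whose payload is passed through unchanged (a minimal CPU-only sketch using Base types):

julia> x = [1];

julia> xs = [copy(x)];

julia> xs .= Ref(x);  # the Ref acts like a scalar: each slot of xs is assigned x itself

julia> xs[1] === x
true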