JuliaGPU / ArrayFire.jl

Julia wrapper for the ArrayFire library

Incompatibility with Zygote #268

Open daniloefl opened 3 years ago

daniloefl commented 3 years ago

Dear ArrayFire developers,

It seems that ArrayFire.jl has a small compatibility issue with Zygote.

A test example follows in [1], but the core of the issue is that Zygote implements a function, Zygote.accum, which simply sums up gradients; its method for AbstractArrays is shown in [2]. That method broadcasts accum over its arguments, assuming the broadcast will call the scalar (non-AbstractArray) method on each element. Unfortunately, ArrayFire's broadcasting instead calls the same function with the whole arrays as arguments, so the AbstractArray method is selected again, causing an endless loop.

The solution could be a simple override of this function for AFArray:

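# More specific than Zygote's accum(x::AbstractArray, y::AbstractArray), so it is selected for AFArrays.
# (The `nothing` checks mirror Zygote's generic accum; they cannot trigger here, since both arguments are AFArrays.)
Zygote.accum(x::AFArray, y::AFArray) =
    x === nothing ? y :
    y === nothing ? x :
    x .+ y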

With this override, everything works, though I am not sure whether other overrides are needed in more general cases. Although the Zygote developers could be brought in, that would create a dependency between Zygote and ArrayFire that is not really necessary. I am not sure there is a cleaner way to solve the issue.

Best regards, Danilo

[1] Test example:

using ArrayFire
using Flux
using DiffEqFlux
using Zygote

hyper = FastChain(FastDense(1, 10, tanh), FastDense(10, 10, tanh), FastDense(10, 16, tanh))
p = initial_params(hyper)
x = rand(Float32, 1, 100)

# This override is required because of a separate indexing issue in DiffEqFlux (unrelated to this bug: without it,
# another incompatibility causes a crash, but I believe that one is a DiffEqFlux issue):
DiffEqFlux.applychain(fs::Tuple, x, p) =
    DiffEqFlux.applychain(Base.tail(fs),
                          first(fs)(x, p[1:DiffEqFlux.paramlength(first(fs))]),
                          length(fs) > 1 ? p[(DiffEqFlux.paramlength(first(fs))+1):end] : Tuple{}())

af_p = AFArray(p)

# this works:
hyper(x, af_p)

# this does not
gs = Flux.gradient(params(af_p)) do
    sum(hyper(x, af_p))
end

The error I get is:

julia> gs = Flux.gradient(params(af_p)) do
                sum(hyper(x, af_p))
                end
ERROR: StackOverflowError:
Stacktrace:
 [1] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:217
 [2] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [3] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:220
 ... (the last 2 lines are repeated 16335 more times)
 [32674] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [32675] applychain at ./REPL[7]:2 [inlined]
 [32676] (::typeof(∂(applychain)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 ... (the last 2 lines are repeated 1 more time)
 [32679] FastChain at /home/daniloefl/.julia/packages/DiffEqFlux/8UHw5/src/fast_layers.jl:21 [inlined]
 [32680] (::typeof(∂(λ)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32681] #5 at ./REPL[16]:2 [inlined]
 [32682] (::typeof(∂(#5)))(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32683] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [32684] gradient(::Function, ::Zygote.Params) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54

[2] From Zygote.jl/src/lib/lib.jl:

accum() = nothing
accum(x) = x

accum(x, y) =
  x === nothing ? y :
  y === nothing ? x :
  x + y

accum(x, y, zs...) = accum(accum(x, y), zs...)

accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
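The loop can be reproduced without a GPU. The following is a minimal sketch, assuming a hypothetical CPU-backed stand-in type `LegacyAF` (not ArrayFire.jl's real AFArray) whose `broadcasted` turns `f.(a, b)` into a single call `f(a, b)` on the whole arrays, as described for ArrayFire's legacy broadcasting in this thread. `accum` here is a local copy of the Zygote methods quoted in [2], not Zygote itself. Adding a more specific method, like the AFArray override proposed above, breaks the cycle.

```julia
import Base.Broadcast: broadcasted

# Hypothetical CPU-backed stand-in for AFArray (not ArrayFire's real type).
struct LegacyAF{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::LegacyAF) = size(a.data)
Base.getindex(a::LegacyAF, i::Int...) = a.data[i...]
# Whole-array addition, standing in for a native ArrayFire kernel.
Base.:+(a::LegacyAF, b::LegacyAF) = LegacyAF(a.data .+ b.data)

# Mimic the legacy behaviour: f.(a, b) becomes a single call f(a, b)
# on the whole arrays instead of applying f element by element.
broadcasted(f::Function, a::LegacyAF, b::LegacyAF) = f(a, b)

# Local copies of the Zygote.accum methods quoted in [2].
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

a = LegacyAF([1.0, 2.0])
b = LegacyAF([3.0, 4.0])

# accum(a, b) -> accum.(a, b) -> broadcasted(accum, a, b) -> accum(a, b) -> ...
overflowed = try
    accum(a, b)
    false
catch err
    err isa StackOverflowError
end

# The workaround: a more specific method that never broadcasts accum itself.
accum(x::LegacyAF, y::LegacyAF) = x + y
fixed = accum(a, b)   # a LegacyAF holding [4.0, 6.0]
```

The override works only because Julia's dispatch prefers the `LegacyAF` method over the `AbstractArray` one; the underlying broadcast behaviour is unchanged, so any other Julia function broadcast in the same way would loop again.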
daniloefl commented 3 years ago

NB: The other issue I mention in the comment in the example code above seems to be related to DiffEqFlux.jl itself, and I have already filed an issue with its developers here:

https://github.com/SciML/DiffEqFlux.jl/issues/436

ghost commented 3 years ago

> Unfortunately, the ArrayFire broadcasting calls the same function with arguments still as AbstractArray, causing an endless loop.

This is a leftover from the pre-v0.7 days of Julia broadcasting: ArrayFire translates a broadcast into a plain function call, so exp.(afarray) simply calls exp(afarray), which in turn calls the C function af_exp. The proper fix would be to reimplement how broadcasting is done.

daniloefl commented 3 years ago

Hello @gaika

I am using Julia 1.5.2. I don't understand your last comment. Can you tell me which steps you suggest?

Best regards Danilo

ghost commented 3 years ago

The bug is here: https://github.com/JuliaGPU/ArrayFire.jl/blob/master/src/array.jl#L217-L233

If you can fix it so that broadcasting goes directly into C code, then your Zygote problem (and a whole class of similar issues) will be gone.

daniloefl commented 3 years ago

The issue is that I don't actually want broadcasting to go straight to C in general (for exp that may make sense, but not here). Zygote.accum is defined in Julia, not in ArrayFire. Furthermore, I don't see anything wrong with its definition: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L5

I am not sure what the best solution is, but it seems to me that there are two very different cases that need to be handled separately:
1) broadcasts that run on the GPU through an internal ArrayFire function (such as exp.(A)); and
2) broadcasts of functions implemented in Julia, which in general have no corresponding ArrayFire function (such as Zygote.accum).
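A minimal sketch of that two-case split, assuming a hypothetical CPU-backed stand-in type `SplitAF` (not ArrayFire.jl's AFArray): only functions assumed to have a native whole-array kernel (here `+` and `exp`, picked for illustration) get a dedicated `broadcasted` method, selected by dispatching on `typeof(f)`. Any other function, such as a Julia-defined `accum` (a local copy of the Zygote methods quoted in [2]), falls through to Base's default broadcasting and is applied element by element, so no recursion occurs.

```julia
import Base.Broadcast: broadcasted

# Hypothetical CPU-backed stand-in for AFArray (not ArrayFire's real type).
struct SplitAF{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::SplitAF) = size(a.data)
Base.getindex(a::SplitAF, i::Int...) = a.data[i...]

# Case 1: functions assumed to have a native whole-array kernel.
# Dispatching on typeof(f) means only these functions take the shortcut.
broadcasted(::typeof(+), a::SplitAF, b::SplitAF) = SplitAF(a.data .+ b.data)
broadcasted(::typeof(exp), a::SplitAF) = SplitAF(exp.(a.data))

# Case 2: any other function has no method above, so Base's default
# broadcasting calls it on each pair of elements.
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

a = SplitAF([1.0, 2.0])
b = SplitAF([3.0, 4.0])

s = a .+ b        # native path: returns a SplitAF
e = exp.(a)       # native path: returns a SplitAF
g = accum(a, b)   # generic path: accum is applied to Float64 elements
```

In this sketch the generic fallback returns a plain `Array`, because `SplitAF` defines no `similar` or broadcast style of its own; a real GPU implementation would presumably need a custom `BroadcastStyle` so that Julia-defined broadcasts also keep their results on the device.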