FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

ArrayFire support #672

Open Nick-Gale opened 4 years ago

Nick-Gale commented 4 years ago

Do you think it is worth supporting ArrayFire as a way of getting GPU acceleration for AMD cards? I naively tried to make it work by loading in AFArrays, but I ran into trouble with broadcasting and have not seen any significant speed-up in general.

Have others managed to make this work? Any help would be appreciated.

Nick-Gale commented 4 years ago

A hack that might be helpful for others is to fork the Zygote package and, in src/lib/lib.jl, change the definition of accum for arrays to:

accum(x::AbstractArray, y::AbstractArray) = x + y
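
For context, Zygote's stock method here broadcasts the accumulation (roughly `accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)`), and it is that broadcast that AFArray struggles with. A minimal plain-Julia sketch of the change, runnable without ArrayFire (the scalar fallback is simplified from Zygote's actual definitions):

```julia
# A simplified stand-in for Zygote's gradient accumulator:
accum(x, y) = x === nothing ? y : (y === nothing ? x : x + y)

# The hack: plain array `+` instead of an elementwise broadcast,
# so accumulation stays on ArrayFire's fused array addition.
accum(x::AbstractArray, y::AbstractArray) = x + y

a = [1.0, 2.0]
b = [3.0, 4.0]
accum(a, b)  # elementwise sum via array `+`
```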

I then defined the custom adjoints as follows:

@adjoint broadcasted(::typeof(+), x::AFArray, y::AFArray) = x .+ y, Δ -> (nothing, ArrayFire.add(Δ, similar(y), true), ArrayFire.add(Δ, similar(x), true))

@adjoint broadcasted(::typeof(-), x::AFArray, y::AFArray) = x .- y, Δ -> (nothing, ArrayFire.add(Δ, similar(y), true), ArrayFire.sub(similar(x), Δ, true))

@adjoint broadcasted(::typeof(*), x::AFArray, y::AFArray) = x .* y, Δ -> (nothing, ArrayFire.mul(Δ, y, true), ArrayFire.mul(Δ, x, true))
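
For reference, the `*` pullback above is just the elementwise product rule: for `z = x .* y` the cotangents are `Δ .* y` (for `x`) and `Δ .* x` (for `y`), which is what the `ArrayFire.mul(Δ, y, true)` calls compute on the device. A plain-array finite-difference check, no ArrayFire required:

```julia
x = [2.0, 3.0]; y = [5.0, 7.0]; Δ = [1.0, 1.0]

dx = Δ .* y   # cotangent w.r.t. x
dy = Δ .* x   # cotangent w.r.t. y

# Directional finite difference of L = sum(Δ .* (x .* y)) when every
# entry of x is shifted by h; this should match sum(dx).
h = 1e-6
fd = (sum(Δ .* ((x .+ h) .* y)) - sum(Δ .* (x .* y))) / h
isapprox(fd, sum(dx); atol = 1e-3)
```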

@adjoint function broadcasted(::typeof(/), x::AFArray, y::AFArray)
    res = x ./ y
    res, Δ -> (nothing, ArrayFire.div(Δ, y, true), ArrayFire.mul(-Δ, res ./ y, true))
end
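
Likewise for `/`: with `res = x ./ y`, the partials are `Δ ./ y` for `x` and `-Δ .* res ./ y` for `y` (since d(x/y)/dy = -x/y² = -(x/y)/y), which matches the `div`/`mul` calls above. A plain-array cross-check against the closed form:

```julia
x = [2.0, 3.0]; y = [4.0, 5.0]; Δ = [1.0, 1.0]
res = x ./ y

dx = Δ ./ y           # d(x ./ y)/dx
dy = -Δ .* res ./ y   # d(x ./ y)/dy, written via res to reuse the forward value

# Same thing in closed form: -x ./ y.^2
all(isapprox.(dy, -x ./ y .^ 2; atol = 1e-12))
```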

@adjoint ctranspose(x::AFArray) = ctranspose(x), Δ -> (ctranspose(Δ), )

@adjoint exp(x::AFArray) = exp(x), Δ -> (similar(x) .+ Δ .* exp(x), )

@adjoint erf(x::AFArray) = erf(x), Δ -> (similar(x) .+ (2 / sqrt(π)) .* Δ .* exp(- x .* x), )

sum(x::AFArray) = sum_all(x)[1]
@adjoint sum(x::AFArray) = sum_all(x)[1], Δ -> (ArrayFire.constant(Δ, size(x)), )
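
For reference, the pullback of `sum` just broadcasts the scalar cotangent back to the input's shape; on plain arrays the CPU analogue of `ArrayFire.constant(Δ, size(x))` is `fill`:

```julia
x = [1.0 2.0; 3.0 4.0]
Δ = 2.0  # cotangent of the scalar sum(x)

# Every entry of x contributes with weight 1 to sum(x), so its
# gradient is Δ in every position.
grad = fill(Δ, size(x))
```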

These changes seem to work fine for my purposes, but this is not really an ideal solution. I think part of the problem comes from how ArrayFire handles broadcasting.