Open daniloefl opened 3 years ago
NB: The other issue I mention in the example code comment above seems to be related to DiffEqFlux.jl itself, and I have already filed an issue with their developers here:
> Unfortunately, the ArrayFire broadcasting calls the same function with arguments still as AbstractArray, causing an endless loop.
This is a leftover from the pre-v0.7 Julia broadcast days: ArrayFire translates a broadcast into a simple function call, so exp.(afarray) just calls exp(afarray), which in turn calls the C function af_exp. A proper fix would be to re-implement how broadcast is done.
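To make the forwarding behaviour described above concrete, here is a minimal sketch in plain Julia. It assumes a hypothetical `MockArray` type and a hypothetical counter `whole_array_calls`; neither comes from ArrayFire.jl, and the forwarding rule is only an imitation of the behaviour described, not the package's actual code.

```julia
# Hypothetical stand-in for a device array; this is NOT ArrayFire's AFArray.
struct MockArray <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(a::MockArray) = size(a.data)
Base.getindex(a::MockArray, i::Int) = a.data[i]

# Counts whole-array calls so that the forwarding can be observed.
const whole_array_calls = Ref(0)

# Whole-array method, standing in for a call into a native routine such as af_exp.
function Base.exp(a::MockArray)
    whole_array_calls[] += 1
    return MockArray(exp.(a.data))
end

# Forwarding rule imitating the description above: f.(a) is turned into f(a).
Base.Broadcast.broadcasted(f, a::MockArray) = f(a)

a = MockArray([0.0, 1.0])
b = exp.(a)   # lowered to broadcasted(exp, a), which calls exp(a) exactly once
```

With such a rule, any function that has no whole-array method, or whose array method itself broadcasts, never reaches an element-wise call.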
Hello @gaika,
I am using Julia 1.5.2. I don't understand your last comment. Can you tell me which steps you suggest?
Best regards, Danilo
The bug is here: https://github.com/JuliaGPU/ArrayFire.jl/blob/master/src/array.jl#L217-L233
If you can fix it so that broadcast goes directly into the C code, then your Zygote issue (and a whole class of similar issues) will be gone.
The issue is that I don't actually want it to go straight to C in general (maybe for exp it makes sense, but here this is not the case). Zygote.accum is defined in Julia and has no counterpart in ArrayFire. Furthermore, I don't see anything wrong with its definition:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L5
I am not sure what the best solution is, but it seems to me that there are two very different cases that need to be handled separately: 1) broadcasts that are done within the GPU with an internal ArrayFire function (such as exp.(A)); and 2) broadcasts of functions implemented in Julia, which in general have no corresponding ArrayFire function (such as the one in Zygote).
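The two cases could be separated roughly as in the following sketch. It is only an illustration under stated assumptions: `MockArray` and `shift` are hypothetical names, the "native" path is simulated on the CPU, and nothing here is ArrayFire.jl's actual broadcast code.

```julia
# Hypothetical stand-in for a device array; this is NOT ArrayFire's AFArray.
struct MockArray <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(a::MockArray) = size(a.data)
Base.getindex(a::MockArray, i::Int) = a.data[i]

# Case 1: a function with a native whole-array routine gets its own rule.
# (map over the wrapped data stands in for the native kernel call.)
Base.Broadcast.broadcasted(::typeof(exp), a::MockArray) = MockArray(map(exp, a.data))

# Case 2: a function written in plain Julia gets no rule at all, so Julia's
# generic broadcast machinery applies it element by element.
shift(x) = x + 1.0

a = MockArray([0.0, 1.0])
native  = exp.(a)     # handled by the exp-specific rule, stays a MockArray
generic = shift.(a)   # handled by the generic fallback, returns a plain Vector
```

On a real GPU array the generic fallback would presumably need its own kernel or scalar indexing, so this only shows the dispatch idea, not a performant design.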
Dear ArrayFire developers,
it seems that ArrayFire.jl has a small compatibility issue with Zygote.
A test example follows in [1], but the core of the issue is that Zygote implements a function, Zygote.accum, which just sums up gradients; for AbstractArrays, it is defined as shown in [2]. It basically uses broadcasting to call itself, assuming that the broadcast would call the non-AbstractArray version of itself on each element. Unfortunately, the ArrayFire broadcasting calls the same function with the arguments still typed as AbstractArray, causing an endless loop.

The solution could be a simple override of this function for AFArray:

With this override, it all works. I am not sure if other overrides are necessary in more general cases, though. Although the Zygote developers could be brought into this discussion, fixing it on their side would create a dependency between Zygote and ArrayFire, which is not really necessary. I am not sure that there is a cleaner way of solving the issue.
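As an illustration of the mechanism, here is a minimal, self-contained sketch. Everything in it is a stand-in: `MockArray` is a hypothetical type, and `accum` is a local function that only mirrors the pattern described above, not Zygote's own code or the exact override from this post. With the real packages, the analogous method would presumably be added to Zygote.accum for AFArray.

```julia
# Hypothetical stand-in for a device array; this is NOT ArrayFire's AFArray.
struct MockArray <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(a::MockArray) = size(a.data)
Base.getindex(a::MockArray, i::Int) = a.data[i]

# Whole-array addition, standing in for a native element-wise add.
Base.:+(a::MockArray, b::MockArray) = MockArray(a.data .+ b.data)

# Forwarding rule imitating the behaviour described: f.(a, b) becomes f(a, b).
Base.Broadcast.broadcasted(f, a::MockArray, b::MockArray) = f(a, b)

# Local stand-in for the accumulation pattern (not Zygote's actual source).
accum(x, y) = x + y
# The array method broadcasts itself. With the forwarding rule above, this
# method alone would call itself forever when given two MockArrays.
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

# The override: add the arrays directly instead of broadcasting accum.
accum(x::MockArray, y::MockArray) = x + y

g = accum(MockArray([1.0, 2.0]), MockArray([3.0, 4.0]))
h = accum([1.0], [2.0])   # ordinary arrays still take the broadcasting path
```

The more specific method wins dispatch for the mock type, so the self-broadcasting method is never reached for it.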
Best regards, Danilo
[1] Test example:
The error I get is:
[2] From Zygote.jl/src/lib/lib.jl: