I am incorporating some simple non-standard elements in a larger DNN.
To achieve this I had to manually define some reverse rules using the ChainRulesCore syntax:
ChainRulesCore.rrule(::typeof(f),args...)
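For context, a rule of this form might look like the sketch below. `scale2` is a hypothetical stand-in chosen only for illustration, not one of my actual layers, and the pullback is written with broadcasting only, on the assumption that this keeps it usable for both `Array` and `CuArray` inputs.

```julia
using ChainRulesCore

# Hypothetical element-wise function, used only to illustrate the rule shape.
scale2(A) = 2 .* A

function ChainRulesCore.rrule(::typeof(scale2), A)
    y = scale2(A)
    # One tangent per input of the call: NoTangent() for the function itself,
    # then the gradient with respect to A.
    scale2_pullback(Δ) = (NoTangent(), 2 .* unthunk(Δ))
    return y, scale2_pullback
end

# Calling the rule directly on a CPU array.
y, pb = rrule(scale2, ones(Float32, 2, 2))
_, dA = pb(ones(Float32, 2, 2))
```

Here `dA` holds the gradient with respect to `A`; the first returned value is the tangent for the function object and is `NoTangent()`.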
This works great on the CPU, but somehow is not portable to the GPU.
I narrowed it down to an MWE in which GPU compilation of the pullback function fails:
using Zygote
using CUDA
f(A) = map(A) do a
    a
end
A = rand(Float32,128,128)
gA = cu(A) # move to the GPU; cu comes from CUDA (gpu would need Flux loaded)
f(A) # works
f(gA) # works
Δ, pb = pullback(f, A) # works
Δ, pb = pullback(f, gA) # fails
Any insights?