FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Zygote Flux and custom adjoints on GPU #1828

Closed: TsurHerman closed this issue 2 years ago

TsurHerman commented 2 years ago

I am incorporating some simple non-standard elements in a larger DNN.

To achieve this, I had to define some reverse-mode rules manually using the ChainRulesCore syntax `ChainRulesCore.rrule(::typeof(f), args...)`. This works well on the CPU, but for some reason it does not carry over to the GPU.
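For reference, here is a minimal sketch of the kind of custom reverse rule described above. It assumes a hypothetical elementwise function `myscale` and is not the reporter's actual code. `rrule` returns the primal result together with a pullback closure. The closure returns one cotangent per argument, with `NoTangent()` in the first slot for the function itself.

```julia
using ChainRulesCore
using Zygote

# Hypothetical elementwise op, used only for illustration.
myscale(x::AbstractArray) = 2 .* x

# Custom reverse rule: returns the primal output and a pullback closure.
function ChainRulesCore.rrule(::typeof(myscale), x::AbstractArray)
    y = myscale(x)
    function myscale_pullback(ȳ)
        # d(2x)/dx = 2, so the cotangent for x is 2 .* ȳ.
        # unthunk materializes ȳ in case it arrives as a lazy thunk.
        return NoTangent(), 2 .* unthunk(ȳ)
    end
    return y, myscale_pullback
end

x = rand(Float32, 4, 4)
# Zygote picks up the rrule above when building the pullback.
y, back = Zygote.pullback(myscale, x)
(gx,) = back(ones(Float32, 4, 4))
```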

I narrowed it down to an MWE that fails during GPU compilation of the pullback function:

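```julia
using Zygote  # provides pullback
using CUDA
using Flux    # provides gpu, which the original snippet used without importing Flux

# Identity function written with map and a do-block
f(A) = map(A) do a
    a
end

A = rand(Float32, 128, 128)
gA = gpu(A)
f(A)   # works
f(gA)  # works

y, back = pullback(f, A)   # works

y, back = pullback(f, gA)  # fails
```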

Any insights?

DhairyaLGandhi commented 2 years ago

Differentiating `map` on CUDA arrays is broken at the moment; you can use broadcasting instead.
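The suggested workaround can be sketched as follows, rewriting the MWE's identity function with broadcasting. `f_bc` is a hypothetical name, and the sketch uses `cu` from CUDA.jl in place of Flux's `gpu`. It also assumes `CUDA.functional()` as a guard so that it falls back to a plain CPU array when no working GPU is present.

```julia
using Zygote
using CUDA

# Same identity as the MWE, but written with broadcasting instead of map.
f_bc(A) = (a -> a).(A)

A = rand(Float32, 128, 128)
# Move to the GPU only when a working CUDA device is present.
X = CUDA.functional() ? cu(A) : A

y, back = pullback(f_bc, X)
# Seed the pullback with ones of the same shape and array type as the output.
(gX,) = back(fill!(similar(y), 1))
# The gradient of the identity is the seed itself; copy to the host to inspect it.
gA = Array(gX)
```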

TsurHerman commented 2 years ago

Yes, this solves the problem. You can close the issue, or leave it open as a reference until this gap is fixed.