fluxcd / flux

Successor: https://github.com/fluxcd/flux2
https://fluxcd.io
Apache License 2.0

Zygote, Flux and custom adjoints on GPU #3584

Closed: TsurHerman closed this issue 2 years ago

TsurHerman commented 2 years ago

I am incorporating some simple non-standard elements into a larger DNN.

To do this, I had to define some reverse-mode rules manually using the ChainRulesCore syntax:

ChainRulesCore.rrule(::typeof(f), args...)
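For readers unfamiliar with this syntax, a custom reverse rule generally has the shape below. This is a minimal sketch that assumes ChainRulesCore is installed. `scaled_square` is a hypothetical stand-in for one of the "non-standard elements" and is not the reporter's actual code. The rule uses only broadcasting and `sum`, which avoids element-by-element indexing, but it has not been tested on a GPU.

```julia
using ChainRulesCore

# Hypothetical element-wise layer: y = s * x^2, with s a scalar.
scaled_square(x, s) = s .* x .^ 2

# Custom reverse rule: return the primal output and a pullback closure.
function ChainRulesCore.rrule(::typeof(scaled_square), x, s)
    y = scaled_square(x, s)
    function scaled_square_pullback(ȳ)
        ȳ = unthunk(ȳ)                  # materialize a possibly lazy cotangent
        x̄ = ȳ .* (2 .* s .* x)          # dy/dx = 2 s x, element-wise
        s̄ = sum(ȳ .* x .^ 2)            # s is a scalar, so reduce over all elements
        return NoTangent(), x̄, s̄        # no tangent for the function object itself
    end
    return y, scaled_square_pullback
end
```

Zygote consults ChainRules `rrule`s when differentiating, so once this method is defined, a call such as `gradient(x -> sum(scaled_square(x, 3f0)), [1f0, 2f0])` should go through the custom rule instead of Zygote's own tracing.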

This works well on the CPU, but for some reason it does not carry over to the GPU.

I narrowed it down to a minimal working example (MWE) in which GPU compilation of the pullback function fails:

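using Zygote
using CUDA
using Flux  # provides `gpu`

# Identity map written with a do-block
f(A) = map(A) do a
    a
end

A = rand(Float32, 128, 128)
gA = gpu(A)  # move the array to the GPU as a CuArray
f(A)   # works
f(gA)  # works

y, back = pullback(f, A)   # works on the CPU

y, back = pullback(f, gA)  # fails on the GPU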

Any insights?

yebyen commented 2 years ago

Sorry, this is the Flux project from fluxcd.io.

(It is unrelated to the Flux from fluxml.ai, which I'm guessing is the one you wanted.)