Closed by ChrisRackauckas 4 years ago.
Is this coming from destructure, or restructure, or does it only appear in the adjoint for one of those, or what?
Those functions don't have scalar indexing inherently, so the issue is likely either a function that CuArrays doesn't have an implementation for (on the forward pass) or an issue with one of Zygote's adjoints for those functions.
I think it's in the adjoint of the restructure. `destructure` isn't called in the gradient call there.
`reshape(xs[i.+(1:length(x))], size(x))`: this `reshape` may not work well in all cases. If I had to guess, it may be hitting something similar to https://github.com/JuliaGPU/CuArrays.jl/issues/548?
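To make that suspected spot concrete: the gradient of `reshape(xs[i.+(1:length(x))], size(x))` with respect to `xs` is a zero vector of the same length, with the incoming cotangent (flattened) written into the range `i.+(1:length(x))`. Below is a minimal CPU-only sketch, assuming a contiguous `UnitRange` index. `slice_reshape_pullback` is a hypothetical helper written for illustration, not Zygote's actual adjoint. It does the write as one broadcast into a view rather than an element-by-element loop, which is the property an adjoint needs in order not to trip `allowscalar(false)` on a GPU array.

```julia
# Hypothetical helper, NOT Zygote's real adjoint: the pullback of
#     y = reshape(xs[r], sz)
# with respect to xs, for a contiguous index range r.
function slice_reshape_pullback(Δ::AbstractArray, xs::AbstractVector, r::UnitRange{Int})
    length(Δ) == length(r) || throw(DimensionMismatch("cotangent size does not match slice"))
    # Zero cotangent with the same array type and length as xs
    dxs = fill!(similar(xs, eltype(Δ)), zero(eltype(Δ)))
    # reshape's pullback just flattens Δ; getindex's pullback writes it into
    # the selected range. `dxs[r] .= ...` broadcasts into a view in one bulk
    # operation instead of looping over individual elements.
    dxs[r] .= vec(Δ)
    return dxs
end

xs = collect(1.0:6.0)           # flat parameter vector
r = 3:4                          # slice that fed one 2×1 parameter array
Δ = reshape([10.0, 20.0], 2, 1)  # incoming cotangent with that array's shape
dxs = slice_reshape_pullback(Δ, xs, r)
```

On CPU this returns `[0.0, 0.0, 10.0, 20.0, 0.0, 0.0]`: only the entries that fed the sliced array receive gradient.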
Possibly. If it's an issue with a Zygote adjoint, it'd be useful to narrow down which specific function is causing the issue. It could be a combination if it's a wrapper issue, but that's less likely.
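Yeah, it's not a wrapper issue:

```julia
using Flux: fmap    # fmap walks the model's fields and rebuilds it
using Adapt: adapt

function _restructure(m, xs)
  i = 0  # running offset into the flat parameter vector xs
  fmap(m) do x
    x isa AbstractArray || return x  # leave non-array leaves unchanged
    # Slice the next length(x) entries, reshape them to x's shape, and
    # try to convert the result back to x's original array type (e.g. a CuArray)
    x = adapt(typeof(x), reshape(xs[i.+(1:length(x))], size(x)))
    i += length(x)
    return x
  end
end
```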
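For readers unfamiliar with `_restructure`: it walks the model's arrays in order, keeps a running offset `i`, and slices the next `length(x)` entries of the flat vector `xs`, reshaped to `size(x)`. Here is a dependency-free sketch of just that bookkeeping, with plain arrays standing in for a model's parameters. The name `restructure_sketch` and the final length check are illustrative additions, not part of Flux. On CPU arrays this path is unproblematic; the open question in this thread is which step's adjoint falls back to scalar indexing once `xs` is a `CuArray`.

```julia
# Hypothetical, dependency-free sketch of the offset-walking idea in
# _restructure: rebuild arrays shaped like `templates` from one flat vector.
function restructure_sketch(templates::Vector{<:AbstractArray}, xs::AbstractVector)
    i = 0                      # running offset into xs
    out = map(templates) do t
        n = length(t)
        # Slice the next n entries and give them t's shape
        y = reshape(xs[i .+ (1:n)], size(t))
        i += n                 # advance past the entries just consumed
        y
    end
    # Illustrative sanity check (not in Flux): every entry must be consumed
    i == length(xs) || error("xs has $(length(xs)) entries, but $i were consumed")
    return out
end

W = zeros(2, 3)   # stands in for a weight matrix
b = zeros(2)      # stands in for a bias vector
flat = collect(1.0:8.0)
W2, b2 = restructure_sketch([W, b], flat)
```

Here `W2` comes back as `[1.0 3.0 5.0; 2.0 4.0 6.0]` (Julia fills in column-major order) and `b2` as `[7.0, 8.0]`.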
It's definitely `restructure`, though: you can make this as minimal as possible and it still shows up with just `restructure`:

```julia
using Flux, Zygote, CuArrays
CuArrays.allowscalar(false)  # make any scalar GPU indexing throw an error

dudt = Dense(1,1) |> gpu
p, re = Flux.destructure(dudt)  # flat parameter vector p and rebuild function re
Zygote.gradient(x -> re(p)(x)[1], cu(rand(1)))
```
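This variant is nicer, since `sum` avoids indexing the `CuArray` output with `[1]`, which is itself a scalar index:

```julia
using Flux, Zygote, CuArrays
CuArrays.allowscalar(false)

dudt = Dense(1,1) |> gpu
p, re = Flux.destructure(dudt)
foo(x) = sum(re(p)(x))
y, back = Zygote._pullback(foo, cu(rand(1)))  # internal API: value and pullback
back(1)  # run the backward pass, where the restructure adjoint executes
```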
So this should be fixed by https://github.com/FluxML/Zygote.jl/pull/474, and tests are in https://github.com/FluxML/Flux.jl/pull/998. Thanks!