FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

destructure/restructure is doing scalar indexing on GPU in back pass #989

Closed ChrisRackauckas closed 4 years ago

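ChrisRackauckas commented 4 years ago

using Flux, Zygote, CuArrays
# Disallow scalar indexing so that any fallback to it on the GPU throws an error
CuArrays.allowscalar(false)
dudt = Chain(Dense(2,50,tanh),Dense(50,2)) |> gpu
# p is the flat parameter vector; re rebuilds the model from it
p,re = Flux.destructure(dudt)
# Differentiating through re(p) triggers scalar indexing in the backward pass
Zygote.gradient(x->sum(re(p)(x)),cu(rand(2)))

MikeInnes commented 4 years ago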

Is this coming from destructure, or restructure, or does it only appear in the adjoint for one of those, or what?

Those functions don't have scalar indexing inherently, so the issue is likely either a function that CuArrays doesn't have an implementation for (on the forward pass) or an issue with one of Zygote's adjoints for those functions.

ChrisRackauckas commented 4 years ago

I think it's in the adjoint of the restructure. destructure isn't called in the gradient call there.

ChrisRackauckas commented 4 years ago

This reshape, reshape(xs[i.+(1:length(x))], size(x)), may not work well in all cases. If I had to guess, it may be hitting something similar to https://github.com/JuliaGPU/CuArrays.jl/issues/548 ?

MikeInnes commented 4 years ago

Possibly. If it's an issue with a Zygote adjoint, it'd be useful to narrow down which specific function is causing the issue. It could be a combination if it's a wrapper issue, but that's less likely.
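
To make the "which adjoint" question concrete, here is a minimal CPU-only sketch of what a pullback for range indexing has to do: scatter the incoming gradient back into a zero vector at the sliced positions. The name slice_with_pullback is hypothetical and this is not Zygote's actual adjoint; it only illustrates that the backward step can be written as a single range assignment rather than an element-by-element loop, which is the kind of difference that matters when scalar indexing is disallowed.

```julia
# Hypothetical hand-written pullback for y = xs[r]; NOT Zygote's implementation.
function slice_with_pullback(xs::AbstractVector, r::AbstractUnitRange)
    y = xs[r]                # forward pass: take a contiguous slice
    function back(dy)
        dxs = zero(xs)       # gradient buffer with the same type and shape as xs
        dxs[r] = dy          # one range assignment, no per-element loop
        return dxs
    end
    return y, back
end

xs = collect(1.0:5.0)
y, back = slice_with_pullback(xs, 2:3)
g = back([10.0, 20.0])       # gradient is nonzero only at positions 2 and 3
```

Whether a given array type handles the range assignment without falling back to scalar indexing depends on that type's own implementation, so this is only a way to isolate the operation under test.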

ChrisRackauckas commented 4 years ago

yeah it's not a wrapper issue:

function _restructure(m, xs)
  i = 0
  fmap(m) do x
    x isa AbstractArray || return x
    x = adapt(typeof(x),reshape(xs[i.+(1:length(x))], size(x)))
    i += length(x)
    return x
  end
end
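
For readers unfamiliar with the pattern, the slicing logic in _restructure can be sketched on the CPU with plain arrays and no Flux dependency. The names flatten and rebuild below are hypothetical stand-ins, not Flux's API, and the sketch leaves out fmap and adapt entirely; it only shows how a running offset i carves a flat vector back into arrays of the original shapes.

```julia
# Hypothetical stand-ins for destructure/restructure on a plain list of arrays.
# Flatten every array into one long vector, in order.
flatten(arrays) = reduce(vcat, vec.(arrays))

# Rebuild arrays with the same shapes as `templates` from the flat vector `xs`.
function rebuild(templates, xs::AbstractVector)
    i = 0                                    # running offset into xs
    out = Any[]
    for x in templates
        # Same slicing expression as in _restructure: range index, then reshape
        push!(out, reshape(xs[i .+ (1:length(x))], size(x)))
        i += length(x)
    end
    return out
end

W = reshape(collect(1.0:6.0), 2, 3)          # stand-in for a weight matrix
b = [7.0, 8.0]                               # stand-in for a bias vector
flat = flatten([W, b])
restored = rebuild([W, b], flat)
```

The round trip returns arrays equal to the originals, so any gradient problem has to come from how xs[i .+ (1:length(x))] or reshape is differentiated, not from the bookkeeping around them.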

it's definitely restructure though since you can make this as minimal as possible and it shows up with just restructure:

using Flux, Zygote, CuArrays
CuArrays.allowscalar(false)
dudt = Dense(1,1) |> gpu
p,re = Flux.destructure(dudt)
Zygote.gradient(x->re(p)(x)[1],cu(rand(1)))

ChrisRackauckas commented 4 years ago

using Flux, Zygote, CuArrays
CuArrays.allowscalar(false)
dudt = Dense(1,1) |> gpu
p,re = Flux.destructure(dudt)
foo(x) = sum(re(p)(x))
y, back = Zygote._pullback(foo, cu(rand(1)))
back(1)

is nicer

CarloLucibello commented 4 years ago

so this should be fixed by https://github.com/FluxML/Zygote.jl/pull/474 and tests are in https://github.com/FluxML/Flux.jl/pull/998. Thanks!