FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
Other
1.48k stars 213 forks source link

Wrong gradient of an expression broadcasting over a Buffer #563

Open AzamatB opened 4 years ago

AzamatB commented 4 years ago

I discovered this bug while implementing a BLSTM (bidirectional LSTM) layer in Julia. Here is a reproducing example (I tried to reduce it further, but the bug keeps evading me):

using Flux
using Flux: @functor, Recur, LSTMCell
using Zygote
using Zygote: Buffer

# Bidirectional LSTM layer: one LSTM is run over the time steps in order and a
# second one over the time steps in reverse order. Both recurrent layers are
# parameterised by the same cell matrix type `M` and vector type `V`.
struct BLSTM{M <: DenseMatrix, V <: DenseVector}
   forward  :: Recur{LSTMCell{M,V}}  # applied to time steps in forward order
   backward :: Recur{LSTMCell{M,V}}  # applied to time steps in reversed order
   outdim   :: Int                   # output size of each direction (total output is 2outdim)
end

# Only the two recurrent sub-layers hold trainable parameters; `outdim` is a plain size.
Flux.trainable(m::BLSTM) = (m.forward, m.backward)
# Register BLSTM with Flux's functor machinery so utilities such as `gpu` can
# map over its fields.
@functor BLSTM

# Build a BLSTM with `in` input features and `out` units per direction, so the
# concatenated output has 2out features. The assembled layer is passed through `gpu`.
function BLSTM(in::Integer, out::Integer)
   forward  = LSTM(in, out)
   backward = LSTM(in, out)
   return BLSTM(forward, backward, Int(out)) |> gpu
end

# Forward pass of the BLSTM, broadcast formulation (the version reported as buggy).
# Input `Xs` is [features × time × batch] (see the permutedims comment below).
# The result is [2outdim × time × batch]: rows 1:outdim come from `m.forward`,
# and rows outdim+1:end come from `m.backward`.
# NOTE: as reported in this issue, differentiating through this version returns
# `nothing` for every parameter gradient, without raising an error.
# NOTE(review): the Recur hidden states are not reset here; presumably the caller
# runs Flux.reset! between sequences — confirm.
function (m::BLSTM)(Xs::DenseArray{<:Real,3})
   Xs = permutedims(Xs, (1, 3, 2)) # [features × time × batch] -> [features × batch × time]
   # preallocate output buffer
   # Zygote.Buffer is an array that may be written with setindex! inside code that
   # Zygote differentiates. Its shape here is [2outdim × time × batch]: after the
   # permutedims above, size(Xs,3) is time and size(Xs,2) is batch.
   Ys = Buffer(Xs, 2m.outdim, size(Xs,3), size(Xs,2))
   axisYs₁ = axes(Ys, 1)
   time_f = axes(Ys, 2)
   time_b = reverse(time_f)  # time steps in the order the backward LSTM visits them
   # get forward and backward slice indices
   slice_f = axisYs₁[1:m.outdim]
   slice_b = axisYs₁[(m.outdim+1):end]
   # bidirectional run step
   # Each line broadcasts over time steps t. It applies the LSTM to the
   # [features × batch] slice Xs[:, :, t] and stores the result in Ys[slice, t, :].
   # Wrapping Ys, Xs and the slice ranges in 1-tuples makes broadcasting treat
   # them as scalars.
   # NOTE(review): m.forward / m.backward are stateful Recur layers, so this relies
   # on broadcast calling them in time-step order.
   setindex!.((Ys,),  m.forward.(view.((Xs,), :, :, time_f)), (slice_f,), time_f, :)
   setindex!.((Ys,), m.backward.(view.((Xs,), :, :, time_b)), (slice_b,), time_b, :)
   return copy(Ys) # [features × time × batch]
end

# Problem sizes: D input features, T time steps, B batch size.
D, T, B = 5, 7, 4
m = BLSTM(D, D÷2)  # D÷2 == 2 units per direction, so 4 output features
θ = params(m)
x = rand(Float32, D, T, B) |> gpu  # [features × time × batch]
julia> gs = gradient(() -> sum(m(x)), θ)
Grads(...)

julia> for p ∈ θ
          @show gs[p]
       end
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing
gs[p] = nothing

Now, if I replace the broadcasting expression in the forward pass with a for loop, it computes the gradients correctly:

# Forward pass of the BLSTM, explicit-loop formulation.
# Same contract as the broadcast version: `Xs` is [features × time × batch], and
# the result is [2outdim × time × batch], with rows 1:outdim from `m.forward` and
# rows outdim+1:end from `m.backward`.
# In the transcript in this issue, this version yields non-`nothing` gradients for
# six of the ten parameters (two matrices and one vector per direction). The four
# `nothing` entries are presumably the cells' initial-state parameters — TODO confirm.
function (m::BLSTM)(Xs::DenseArray{<:Real,3})
   Xs = permutedims(Xs, (1, 3, 2)) # [features × time × batch] -> [features × batch × time]
   # preallocate output buffer
   # Buffer shape is [2outdim × time × batch]: after the permutedims above,
   # size(Xs,3) is time and size(Xs,2) is batch.
   Ys = Buffer(Xs, 2m.outdim, size(Xs,3), size(Xs,2))
   axisYs₁ = axes(Ys, 1)
   time_f = axes(Ys, 2)
   time_b = reverse(time_f)  # time steps in the order the backward LSTM visits them
   # get forward and backward slice indices
   slice_f = axisYs₁[1:m.outdim]
   slice_b = axisYs₁[(m.outdim+1):end]
   # bidirectional run step
   # Walk time forward (t_f) and backward (t_b) in lockstep. @views makes the
   # right-hand Xs[:, :, t] slices views instead of copies.
   @views for (t_f, t_b) ∈ zip(time_f, time_b)
      Ys[slice_f, t_f, :] =  m.forward(Xs[:, :, t_f])
      Ys[slice_b, t_b, :] = m.backward(Xs[:, :, t_b])
   end
   return copy(Ys) # [features × time × batch]
end
julia> gs = gradient(() -> sum(m(x)), θ)
Grads(...)

julia> for p ∈ θ
          @show gs[p]
       end
gs[p] = Float32[0.5101168 0.7152076 0.73191303 0.499395 0.53787124; -0.64906085 -0.5532551 -0.6976129 -0.5867471 -0.53961146; 0.83282167 0.737208 0.79622936 0.7672653 0.6858342; -1.3397851 -1.0657941 -1.2099426 -1.0391285 -1.155204; 10.094865 8.342223 9.1953335 8.667656 8.725371; 3.9081998 3.2401364 3.5865738 3.3855665 3.5111141; 1.110752 1.0483335 0.9643827 0.8357487 0.8817357; -3.5713134 -3.1898558 -3.1710458 -3.1021874 -2.9224894]
gs[p] = Float32[0.15129018 -0.39292163; -0.16212332 0.3713716; 0.2302244 -0.47000247; -0.31288114 0.71879077; 2.3276496 -5.5051193; 0.924201 -2.1497304; 0.3037261 -0.6143787; -0.9125652 2.0442908]
gs[p] = Float32[1.2720479, -1.2275351, 1.5529392, -2.3684156, 18.073097, 7.08892, 1.979319, -6.576381]
gs[p] = nothing
gs[p] = nothing
gs[p] = Float32[-1.8887122 -1.6791371 -1.4836435 -1.5126536 -1.5024879; 0.8776411 0.8338405 0.54689103 0.6299987 0.6667515; -1.823601 -1.6581113 -1.4419314 -1.4415677 -1.4066488; 0.7889698 0.70566267 0.57653344 0.5739659 0.6011479; 1.9840837 1.7578385 1.289188 1.3200386 1.3633783; 1.4536479 1.435955 1.3797476 1.3699763 1.1075537; -2.5608983 -2.275478 -2.19252 -2.16371 -2.062906; 2.084642 1.8295443 1.8279495 1.7722094 1.7149571]
gs[p] = Float32[1.0405383 -0.6534959; -0.45830482 0.28144148; 1.0043947 -0.63153327; -0.43334833 0.27105567; -1.0180058 0.63447845; -0.92075074 0.5782949; 1.4423877 -0.9218084; -1.1742411 0.7575369]
gs[p] = Float32[-3.357407, 1.4713798, -3.2320528, 1.398481, 3.2924507, 2.97262, -4.6464305, 3.7910182]
gs[p] = nothing
gs[p] = nothing

Also pinging @mcabbott, as I previously discussed this with him.

mcabbott commented 4 years ago

Here's a much simpler example if anyone would like to play with it. I don't know why you don't get errors when you broadcast `setindex!` above. I also don't know what the benefit of writing it that way would be, even if it worked.

using Zygote, ForwardDiff
# Minimal callable "layer": wraps a weight matrix W and multiplies its input by it.
struct F W end
(f::F)(x) = f.W * x
# Test data: a 3×2 weight matrix and five random length-2 inputs.
w = rand(3,2);
f = F(w)  # built once, outside any closure (contrast the Params calls below)
xs = [rand(2) for _=1:5];
x = reduce(hcat, xs)  # the five inputs stacked as columns of a 2×5 matrix

# Apply `f` to each column of `M` in an explicit loop. Each length-`d` result is
# written into the matching column of a Zygote Buffer. Returns a d × size(M,2) array.
function bufloop(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    for i in axes(M,2)
        B[:,i] = f(M[:,i])
    end
    copy(B)  # convert the Buffer back to an ordinary array
end
# Same computation as bufloop, attempted with map! into per-column views of the Buffer.
# The calls below mark this as erroring (view(::Zygote.Buffer)).
function bufmap(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    # map!(f, eachcol(B), eachcol(M)) # no eachcol method
    map!(f, [view(B,:,i) for i in axes(M,2)], collect(eachcol(M))) # no view?
    copy(B)
end
# Same computation as bufloop, written as a broadcast of setindex! over column
# indices (the pattern from the original report). Ref(B) keeps the Buffer from
# being broadcast over. The forward call below works; its gradient is marked as
# erroring with iterate(::Nothing).
function bufcast(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    # view.(Ref(B), :,axes(M,2)) .= f.(eachcol(M)) # no view?
    cols = collect(eachcol(M))
    setindex!.(Ref(B), f.(cols), :, axes(M,2)) 
    # setindex!.(Ref(B), f.(view.(Ref(M), :, axes(M,2))), :, axes(M,2)) 
    copy(B)
end

# Reference forward result: apply F(w) to each column of x via mapslices.
fwd = mapslices(F(w), x, dims=1)
bufloop(F(w), x, 3)  # ok
# bufmap(F(w), x, 3) # error
bufcast(F(w), x, 3)  # ok

# Reference gradient of the column-wise sum with respect to w, computed with ForwardDiff.
rev = ForwardDiff.gradient(w -> sum(mapslices(F(w), x, dims=1)), w)

# Explicit-argument gradient works, but implicit Params gives a wrong answer when
# F(w) is constructed inside the closure. Using the pre-built `f` (same `w`) is ok.
Zygote.gradient(w -> sum(bufloop(F(w), x, 3)), w)[1]           # works
Zygote.gradient(()-> sum(bufloop(F(w), x, 3)), Params([w]))[w] # wrong answer but runs
Zygote.gradient(()-> sum(bufloop(f,    x, 3)), Params([w]))[w] # ok

Zygote.gradient(w -> sum(bufmap(F(w), x, 3)), w)[1] # error, view(::Zygote.Buffer)

Zygote.gradient(w -> sum(bufcast(F(w), x, 3)), w)[1] # error, iterate(::Nothing)

(The wrong gradients are the same issue as https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652 )

AzamatB commented 4 years ago

@mcabbott I'm not sure it's the same bug, though, since my example above does not throw an error but rather silently gives the wrong result.

mcabbott commented 4 years ago

Well, what's different? I'm not even sure whether it's a bug or just unsupported behaviour... in which case perhaps not giving an error is the bug.