Open AzamatB opened 4 years ago
Here's a much simpler example if anyone would like to play. I don't know why you don't get errors when you broadcast setindex! above. I also don't know what the benefit of writing it that way might be, if it worked.
using Zygote, ForwardDiff
struct F W end
(f::F)(x) = f.W * x
w = rand(3,2);
f = F(w)
xs = [rand(2) for _=1:5];
x = reduce(hcat, xs)
function bufloop(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    for i in axes(M,2)
        B[:,i] = f(M[:,i])
    end
    copy(B)
end
function bufmap(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    # map!(f, eachcol(B), eachcol(M)) # no eachcol method
    map!(f, [view(B,:,i) for i in axes(M,2)], collect(eachcol(M))) # no view?
    copy(B)
end
function bufcast(f, M, d)
    B = Zygote.Buffer(M, d, size(M,2))
    # view.(Ref(B), :,axes(M,2)) .= f.(eachcol(M)) # no view?
    cols = collect(eachcol(M))
    setindex!.(Ref(B), f.(cols), :, axes(M,2))
    # setindex!.(Ref(B), f.(view.(Ref(M), :, axes(M,2))), :, axes(M,2))
    copy(B)
end
fwd = mapslices(F(w), x, dims=1)
bufloop(F(w), x, 3) # ok
# bufmap(F(w), x, 3) # error
bufcast(F(w), x, 3) # ok
rev = ForwardDiff.gradient(w -> sum(mapslices(F(w), x, dims=1)), w)
Zygote.gradient(w -> sum(bufloop(F(w), x, 3)), w)[1] # works
Zygote.gradient(()-> sum(bufloop(F(w), x, 3)), Params([w]))[w] # wrong answer but runs
Zygote.gradient(()-> sum(bufloop(f, x, 3)), Params([w]))[w] # ok
Zygote.gradient(w -> sum(bufmap(F(w), x, 3)), w)[1] # error, view(::Zygote.Buffer)
Zygote.gradient(w -> sum(bufcast(F(w), x, 3)), w)[1] # error, iterate(::Nothing)
(The wrong gradients are the same issue as https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652 )
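To judge which of the results above count as "wrong", it may help to have a reference gradient that does not depend on Zygote at all. The following is only a sketch under the assumption that the loss is sum(W * X) with the same shapes as in the example (W is 3×2, X is 2×5). The helper name finite_diff is hypothetical and not part of any package. Since the loss is linear in W, its gradient should be ones(3,5) * X', which the central finite difference confirms.

```julia
# Same shapes as the example above: W is 3×2, X is 2×5 (assumed for illustration).
W = rand(3, 2)
X = rand(2, 5)

# Loss used in the gradient calls above: sum over every entry of W * X.
loss(W) = sum(W * X)

# Hand-derived gradient: d/dW[i,j] of sum(W*X) is sum_k X[j,k],
# so every row of the gradient holds the row sums of X.
analytic = ones(size(W, 1), size(X, 2)) * X'

# Hypothetical helper: central finite differences, one entry of W at a time.
function finite_diff(f, W; h=1e-6)
    G = similar(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        G[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return G
end

numeric = finite_diff(loss, W)
@assert isapprox(analytic, numeric; rtol=1e-6)
```

Any gradient returned for this loss, whether from the explicit or the Params form, can then be compared against analytic with isapprox.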
@mcabbott I'm not sure it's the same bug, though, since my example above does not throw an error but silently gives the wrong result.
Well, what's different? I'm not even sure it's a bug rather than just unsupported behaviour, in which case perhaps not giving an error is the bug.
I discovered this bug while implementing a BLSTM layer in Julia. Here is the reproducing example (I tried to reduce it further, but the bug keeps evading reduction):
Now if I replace the broadcasting expression in the forward pass with a for loop, it computes the gradients correctly.
Also pinging @mcabbott, as I previously discussed this with him.