jonny-so opened this issue 4 years ago.
@willtebbutt
@MikeInnes, is this intended behaviour? My guess is that indexing into an array of arrays creates a newly referenced object, so tracking the arrays by reference breaks down?
"Flux" => v"0.10.1"
Using Flux.params(p...) might work?
Yeah, there are a couple of workarounds; I'm more interested in why this doesn't work. If it's intended behaviour, that's fine, but we should have an FAQ entry (and more documentation generally around parameter handling). Obviously, if it's a bug, then it needs to be fixed.
Yes, this is a bug. When you pull a value out of a (global) array (or a struct, etc.), Zygote needs to check whether that value is something it should be tracking and, if so, also accumulate its gradient globally. accum_param handles that, and the code for getproperty (structs) is here; the code that needs to be added to getindex is very similar. It's also worth checking tuples, since they have their own indexing functions.
Aside: I think there may be an unhandled edge case if you create a struct of params inside the target function; we might then accumulate the gradient twice. So if the value is a param, its gradient should be accumulated and not also backpropagated. This would be worth looking into.
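The mechanism described above (check whether an indexed value is a tracked param, accumulate its gradient globally, and, following the aside, not backpropagate that same contribution a second time) can be sketched in plain Julia. This is a hypothetical toy under stated assumptions: ToyGrads, toy_accum_param! and toy_getindex_pullback! are made-up names, not Zygote's accum_param or its real getindex adjoint, whose actual signatures and behaviour may differ.

```julia
# Hypothetical toy model of implicit-parameter gradients.
# None of these names are Zygote's API; they only illustrate the idea.

# Gradients of tracked params, keyed by object identity (like implicit Params).
const ToyGrads = IdDict{Any,Any}

# If x is a tracked param, add Δ to its global gradient and return nothing,
# so the same contribution is not also backpropagated (and counted twice).
# Otherwise, pass Δ through unchanged.
function toy_accum_param!(grads::ToyGrads, x, Δ)
    haskey(grads, x) || return Δ
    grads[x] = grads[x] === nothing ? Δ : grads[x] .+ Δ
    return nothing
end

# Pullback for y = xs[i], where xs is a container such as a global
# Vector of arrays. Returns the gradient with respect to xs itself.
function toy_getindex_pullback!(grads::ToyGrads, xs::AbstractVector, i::Integer, Δ)
    Δ = toy_accum_param!(grads, xs[i], Δ)
    Δxs = Vector{Any}(nothing, length(xs))
    Δxs[i] = Δ
    return Δxs
end

# Usage: a global array holding two tracked params.
w1 = ones(2, 2); w2 = ones(3)
ps = [w1, w2]
grads = ToyGrads(w1 => nothing, w2 => nothing)
toy_getindex_pullback!(grads, ps, 1, fill(2.0, 2, 2))
toy_getindex_pullback!(grads, ps, 1, fill(1.0, 2, 2))  # a second read of ps[1] accumulates
grads[w1]  # 2×2 matrix filled with 3.0; grads[w2] is still nothing
```

Keying on object identity (IdDict) rather than value equality is what lets the check recognise that ps[1] is the very array being tracked.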
I have an example that baffled me for a while, which I think must be precisely the double accumulation mentioned above.
```julia
using Zygote, ForwardDiff

struct F
    W
end
(f::F)(x) = f.W * x

w = rand(3, 2);
xs = [rand(2) for _ = 1:5];

ForwardDiff.gradient(w -> sum(sum, map(F(w), xs)), w)  # ground truth
Zygote.gradient(w -> sum(sum, map(F(w), xs)), w)[1]     # right

# Implicit-parameter version, with the struct built inside the loss function
g = Zygote.gradient(() -> sum(sum, map(F(w), xs)), Zygote.Params([w]))
g[w]       # wrong
g[w] ./ 2  # right

f = F(w)  # construct the struct outside the loss function
gf = Zygote.gradient(() -> sum(sum, map(f, xs)), Zygote.Params([w]))
gf[w]  # right
```
@mcabbott, your example here seems to be fixed on master.
Indeed, thanks. But the original one is not.
Could this be related to https://github.com/FluxML/Zygote.jl/issues/692?