FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

gradients of functions that implicitly use vectors of parameters #522

Open · opened 4 years ago by jonny-so

jonny-so commented 4 years ago
using Flux, LinearAlgebra

p = [randn(5), randn(5)]
p1 = p[1]
p2 = p[2]
function myloss()
    norm(p1) + norm(p2)
end
grad = gradient(Flux.Zygote.Params(p)) do
  myloss()
end
display(grad[p[1]])

displays

5-element Array{Float64,1}:
 -0.16481553854066222
 -0.2893911650465771 
  0.41283960683427506
  0.8453105091581221 
 -0.06404837222204707

whereas

p = [randn(5), randn(5)]
function myloss()
    norm(p[1]) + norm(p[2])
end
grad = gradient(Flux.Zygote.Params(p)) do
  myloss()
end
display(grad[p[1]])

displays nothing

jonny-so commented 4 years ago

@willtebbutt

willtebbutt commented 4 years ago

@MikeInnes is this intended behaviour? My guess is that because indexing into an array of arrays creates a newly-referenced object, tracking the reference of the array breaks down?
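
A quick check of the reference semantics (a sketch probing plain Julia indexing only, not Zygote's tracking):

p = [randn(5), randn(5)]
p1 = p[1]
p1 === p[1]  # true: getindex returns the stored reference, not a fresh object

So the inner arrays keep a stable identity at the Julia level; if tracking breaks down, it is in how the getindex pullback reports gradients, as the comments below work out.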

jonny-so commented 4 years ago

"Flux" => v"0.10.1"

DhairyaLGandhi commented 4 years ago

Using Flux.params(p...) might work?
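
A minimal sketch of that suggestion applied to the original example (hedged: this shows what the splatting looks like, not a confirmed fix):

using Flux, LinearAlgebra

p = [randn(5), randn(5)]
myloss() = norm(p[1]) + norm(p[2])

# Flux.params(p...) passes p[1] and p[2] individually,
# rather than the outer array p.
grad = gradient(Flux.params(p...)) do
    myloss()
end
grad[p[1]]  # if the workaround applies, a 5-element gradient instead of nothing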

willtebbutt commented 4 years ago

Yeah, there are a couple of workarounds -- I'm more interested in why this doesn't work. If it's intended behaviour, that's fine, but we should have an FAQ entry (and more documentation around parameter handling generally). Obviously, if it's a bug, it needs to be fixed.

MikeInnes commented 4 years ago

Yes, this is a bug. When you pull a value out of a (global) array (or struct, etc.), Zygote needs to check whether that value is something it should be tracking and, if so, accumulate its gradient globally too. accum_param handles that, and the code for getproperty (structs) is here; the code that needs to be added to getindex is very similar. It's also worth checking tuples, since they have their own indexing functions.
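
A rough sketch of the shape of that fix (illustrative only, not Zygote's actual source -- accum_param and __context__ are real Zygote internals, but the pullback body here is simplified and restricted to integer indexing):

using Zygote
using Zygote: @adjoint, accum_param

# Illustrative only: loading this would shadow Zygote's real getindex adjoint.
@adjoint function Base.getindex(xs::Array, i::Int...)
    y = xs[i...]
    function getindex_pullback(Δ)
        accum_param(__context__, y, Δ)  # accumulate globally if y is tracked
        Δxs = map(zero, xs)             # zero cotangent everywhere else
        Δxs[i...] = Δ
        return (Δxs, map(_ -> nothing, i)...)
    end
    return y, getindex_pullback
end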

Aside: I think there may be an unhandled edge case here: if you create a struct-of-params within the target function, we might then double-accumulate the gradient. So if the value is a param, the gradient should be accumulated and not backpropagated. This would be worth looking into.

mcabbott commented 4 years ago

I have an example that baffled me for a while, which I think must be precisely the double accumulation mentioned above.

using Zygote, ForwardDiff

struct F W end        # callable struct wrapping a weight matrix
(f::F)(x) = f.W * x
w = rand(3,2);
xs = [rand(2) for _=1:5];

ForwardDiff.gradient(w -> sum(sum, map(F(w), xs)), w) # ground truth
Zygote.gradient(     w -> sum(sum, map(F(w), xs)), w)[1] # right

g = Zygote.gradient(() -> sum(sum, map(F(w), xs)), Zygote.Params([w]))
g[w]      # wrong
g[w] ./ 2 # right

f = F(w) # construct the struct outside the loss function
gf = Zygote.gradient(() -> sum(sum, map(f, xs)), Zygote.Params([w]))
gf[w] # right

AzamatB commented 4 years ago

@mcabbott your example here seems to be fixed on master.

mcabbott commented 4 years ago

Indeed, thanks. But the original one is not.
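
For completeness, a self-contained recheck of the original report (a sketch using Zygote directly rather than through Flux, reflecting the behaviour described in this thread):

using Zygote, LinearAlgebra

p = [randn(5), randn(5)]
g = gradient(() -> norm(p[1]) + norm(p[2]), Params(p))
g[p[1]]  # still nothing: the indexed parameters receive no gradient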

cossio commented 4 years ago

Could this be related to https://github.com/FluxML/Zygote.jl/issues/692?