FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Gradient over implicit parameters returns nothing #692

Closed cossio closed 2 years ago

cossio commented 4 years ago

I have encountered this issue several times. This is the smallest example I was able to find to reproduce it.

using Flux, Zygote
using Zygote: @adjoint
struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
fun(s::S) = sum(s.W)
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end
gs = gradient(ps) do
    fun(s)
end
gs[s.W] # nothing
gradient(w -> fun(S(w)), randn(2,2)) # correct gradient

I noticed that gs is storing the correct gradients in W in another key, which equals Main.s, but I'm not even sure what that is and I cannot access it.

julia> gs.grads
IdDict{Any,Any} with 2 entries:
  [-0.453576 0.131353 0.0619522 0.126699; -0.172607 0.306845 0.522566 1.4498; 0.82781 -0.222564 -0.104318 0.0206807; -0.47… => nothing
  :(Main.s)  => (W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],)

But the correct key s.W is populated with nothing, which is wrong.

What is going on here?

cossio commented 4 years ago

If I remove the custom adjoint, it works fine. So it has to be something related to the interaction between params and a custom @adjoint.

struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
foo(s::S) = sum(s.W)
gs = gradient(ps) do
    foo(s)
end
gs[s.W] # correct gradient

cossio commented 4 years ago

A workaround is to write an intermediary function that takes only array inputs:

struct S
    W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
fff(s::S) = _fff(s.W)
_fff(w) = sum(sin.(w))
@adjoint function _fff(w)
    _fff(w), Δ -> (similar(w) .= Δ .* cos.(w),)
end
gs = gradient(ps) do
    fff(s)
end
gs[s.W] # correct gradients

cossio commented 4 years ago

@MikeInnes Any idea what is happening here? This issue is producing wrong gradients silently, and it took me a while just to figure out the bug originated in Zygote.

ToucheSir commented 2 years ago

I don't believe Zygote can track implicit params usage in adjoint functions (by design I assume, otherwise there'd be no way to avoid AD in custom adjoints). So if s.W doesn't show up in a place that the AD has visibility over, it won't have a gradient. Is there any reason you can't work with a structural gradient for s?

darsnack commented 2 years ago

To clarify a bit more, your custom adjoint means that the computation graph that the AD system works with looks like:

s -> f(s) -> output

In other words, it never "sees" the array s.W. The intermediate function avoids this by:

s -> getproperty(s, :W) -> _fff(w) -> output

Here, Zygote uses your custom adjoint to get the gradient w.r.t. w for _fff(w). The pullback of the getproperty(s, :W) call is then where the AD "sees" s.W as an implicit parameter and accumulates your custom adjoint's output into it.

Similarly, without the custom adjoint, the AD has

s -> getproperty(s, :W) -> sum(x) -> output

Here, sum(x) plays a similar role to _fff(w). The key point is that the AD actually "sees" the array returned by getproperty.
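
The two graphs can be compared side by side. This is a minimal sketch that restates the examples from earlier in the thread; the names opaque, inner and visible are hypothetical, and Zygote.Params is constructed directly in place of Flux's params(s), on the assumption that both yield a parameter set containing s.W.

```julia
using Zygote
using Zygote: @adjoint, Params

struct S
    W::Array{Float64}
end

# Case 1: the custom rule is attached to the function that takes the struct,
# so the AD never differentiates through the access to `s.W`.
opaque(s::S) = sum(s.W)
@adjoint function opaque(s::S)
    opaque(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

# Case 2: the custom rule is attached to an inner function of the array,
# so `getproperty(s, :W)` stays visible to the AD.
inner(w) = sum(w)
@adjoint function inner(w)
    inner(w), Δ -> (similar(w) .= Δ,)
end
visible(s::S) = inner(s.W)

s = S(randn(4, 4))
ps = Params([s.W])

gs_opaque = gradient(() -> opaque(s), ps)
gs_visible = gradient(() -> visible(s), ps)

gs_opaque[s.W]   # nothing, as reported in this issue
gs_visible[s.W]  # 4×4 matrix of ones
```

Only the second closure routes the cotangent through getproperty, which is the step that accumulates into the implicit parameter.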

cossio commented 2 years ago

@darsnack You consider this resolved?

darsnack commented 2 years ago

Resolved == can't fix? (anyone feel free to reopen in case I'm wrong)

Yes, my understanding is that this is by design for adjoints. Writing a custom rule forces the AD to look away, and I don't think we would merge a change that breaks that fundamental assumption.

The only alternative fix I see on Zygote's end would be to accumulate into the implicit params after the pullback, since Zygote computes the structural gradient anyway. This would require recursively traversing all the values. Maybe @mcabbott can comment on the correctness/feasibility of this.
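
To make the idea concrete, here is a rough user-side sketch of that traversal. It is not something Zygote does today: accumulate_structural! is a hypothetical helper, and it only handles structs whose structural gradient is a NamedTuple and leaves that are numeric arrays.

```julia
using Zygote
using Zygote: @adjoint

struct S
    W::Array{Float64}
end

fun(s::S) = sum(s.W)
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

# Hypothetical helper: walk a value and its structural gradient together and
# record the gradient of every numeric array leaf, keyed by object identity.
function accumulate_structural!(grads::IdDict, x, dx)
    dx === nothing && return grads
    if x isa AbstractArray{<:Number}
        # Sum the contributions if the same array is reached more than once.
        grads[x] = haskey(grads, x) ? grads[x] .+ dx : dx
    elseif dx isa NamedTuple
        for name in keys(dx)
            accumulate_structural!(grads, getfield(x, name), dx[name])
        end
    end
    return grads
end

s = S(randn(4, 4))
(ds,) = gradient(fun, s)                        # structural gradient: (W = ...,)
grads = accumulate_structural!(IdDict(), s, ds)
grads[s.W]                                      # 4×4 matrix of ones
```

Keying on identity mirrors how implicit params are looked up, so an array shared between two fields would receive the sum of both contributions.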

The recommended fixes here are to:

darsnack commented 2 years ago

Structural gradient for reference:

gradient(fun, S(rand(2, 2)))

cossio commented 2 years ago

Ok, thanks! I'll also put this example from @ToucheSir (posted on Slack) here for future reference.

struct S
    W::Array{Float64}
end
s = S(randn(4,4))
fun(s::S) = sum(s.W)
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end
julia> gs = gradient(s) do s
           fun(s)
       end
((W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],),)
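
A structural gradient like this one can be consumed directly, without going through params. A minimal sketch, assuming a hypothetical quadratic loss and step size η (neither is from the original example):

```julia
using Zygote

struct S
    W::Array{Float64}
end

# Hypothetical loss, chosen so the gradient is easy to check by hand: d/dW = 2W.
loss(s::S) = sum(abs2, s.W)

s = S(randn(4, 4))
W_before = copy(s.W)

# The gradient w.r.t. the struct is a NamedTuple with the same field names.
(ds,) = gradient(loss, s)

# One gradient-descent step, updating the array in place.
η = 0.1
s.W .-= η .* ds.W
```

Because the gradient mirrors the fields of S, the update reads ds.W next to s.W and no lookup by array identity is involved.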

cossio commented 2 years ago

@darsnack One more question. Is there an alternative way I could have written the above explicit adjoint so that this example works fine?

ChrisRackauckas commented 2 years ago

> I don't believe Zygote can track implicit params usage in adjoint functions (by design I assume, otherwise there'd be no way to avoid AD in custom adjoints). So if s.W doesn't show up in a place that the AD has visibility over, it won't have a gradient. Is there any reason you can't work with a structural gradient for s?

I would go even further and say Zygote should remove the implicit parameter system, or at least Flux should. It seemed like a good idea but 3 years later I think we've all learned it only causes pain. The main gain was syntactic sugar, but the system underlying it never really was that solid. This is just one of many unsolvable issues that arise from it, others being performance or compile time related, along with other weird correctness edge cases. Instead, we should all explore different ways to make explicit parameters have similarly nice syntax, and that would be the best of all worlds.

DhairyaLGandhi commented 2 years ago

We have the explicit form, and that would be good to use. Elsewhere, implicit gradients are tracked over the same rules as explicit ones. There's no design constraint on why one should work and the other not.

ToucheSir commented 2 years ago

@ChrisRackauckas I would love nothing more, but there unfortunately hasn't been a big push behind figuring out how to bring explicit params to parity, let alone a migration plan. This includes non-syntactic issues such as how to do tied weights and how to exclude certain params from optimization. I myself have at least a couple pages of design notes on various aspects/challenges, and looking at what others are doing it's clear this is not a trivial task!

Anyhow, I just created a tracking project at https://github.com/orgs/FluxML/projects/2. Please add new issues/tasks as you encounter them. It would be great to record all this disparate discussion about implicit vs explicit params in one place.