Closed: cossio closed this issue 2 years ago.
If I remove the custom adjoint, it works fine, so it has to be something related to the interaction between params and a custom @adjoint. Without the adjoint:
using Flux  # provides params, gradient and @functor
struct S
W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
foo(s::S) = sum(s.W)
gs = gradient(ps) do
foo(s)
end
gs[s.W] # correct gradient
A workaround is to write an intermediary function that takes only array inputs:
using Flux
using Zygote: @adjoint
struct S
W::Array{Float64}
end
Flux.@functor S
s = S(randn(4,4))
ps = params(s)
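# thin wrapper: the field access s.W happens where the AD can see it
fff(s::S) = _fff(s.W)
# array-only helper that carries the custom rule
_fff(w) = sum(sin.(w))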
@adjoint function _fff(w)
_fff(w), Δ -> (similar(w) .= Δ .* cos.(w),)
end
gs = gradient(ps) do
fff(s)
end
gs[s.W] # correct gradients
@MikeInnes Any idea what is happening here? This issue is producing wrong gradients silently, and it took me a while just to figure out the bug originated in Zygote.
I don't believe Zygote can track implicit params usage in adjoint functions (by design I assume, otherwise there'd be no way to avoid AD in custom adjoints). So if s.W doesn't show up in a place that the AD has visibility over, it won't have a gradient. Is there any reason you can't work with a structural gradient for s?
To clarify a bit more, your custom adjoint means that the computation graph the AD system works with looks like:
s -> f(s) -> output
In other words, it never "sees" the array s.W. The intermediate function avoids this:
s -> getproperty(s, :W) -> _fff(w) -> output
Here, Zygote uses your custom adjoint for the gradient of _fff(w) with respect to w. Then, at the call to getproperty(s, :W), the AD "sees" s.W as an implicit parameter and accumulates your custom adjoint's output into it.
Similarly, without the custom adjoint, the AD has:
s -> getproperty(s, :W) -> sum(x) -> output
Here, sum(x) plays the same role as _fff(w). Basically, the key is that the AD actually "sees" the array returned by getproperty.
@darsnack You consider this resolved?
Resolved == can't fix? (anyone feel free to reopen in case I'm wrong)
Yes, my understanding is that this is by design for adjoints. Writing a custom rule forces the AD to look away, and I don't think we would merge a change that breaks that fundamental assumption.
The only alternative fix I see on Zygote's end would be post-pullback accumulation into the implicit params, since Zygote computes the structural gradient anyway. This would require recursively traversing all the values. Maybe @mcabbott can comment on the correctness/feasibility of this.
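To make the post-pullback idea concrete, here is a minimal sketch in plain Julia, without Zygote. The helper accumulate_leaves! is hypothetical (not a Zygote function): it walks a value alongside its structural gradient (nested NamedTuples/Tuples) and adds every array leaf that is registered as an implicit parameter into an IdDict keyed by object identity, loosely mirroring how implicit gradients are stored.

```julia
struct S
    W::Matrix{Float64}
end

# Hypothetical helper (not part of Zygote): walk `x` together with its
# structural gradient `dx` and add each array leaf that is already a key
# of `grads` (i.e. a registered implicit parameter) into that slot.
function accumulate_leaves!(grads::IdDict{Any,Any}, x, dx)
    dx === nothing && return grads               # no gradient for this branch
    if x isa AbstractArray{<:Number}
        if haskey(grads, x)                      # only tracked parameters get a slot
            old = grads[x]
            grads[x] = old === nothing ? dx : old .+ dx
        end
    elseif dx isa NamedTuple                     # struct: recurse field by field
        for name in keys(dx)
            accumulate_leaves!(grads, getfield(x, name), getfield(dx, name))
        end
    elseif dx isa Tuple                          # tuple: recurse element-wise
        for (xi, dxi) in zip(x, dx)
            accumulate_leaves!(grads, xi, dxi)
        end
    end
    return grads
end

s = S(randn(4, 4))
grads = IdDict{Any,Any}()
grads[s.W] = nothing                  # s.W registered as an implicit param, no gradient yet
ds = (W = ones(4, 4),)                # structural gradient, e.g. from a custom adjoint on s
accumulate_leaves!(grads, s, ds)
grads[s.W]                            # now holds the all-ones gradient
```

Because the dictionary is keyed by identity, an array reachable through two fields would accumulate both contributions; a real implementation would also have to deal with non-array leaves, mutable wrappers and the cost of the traversal, which this sketch ignores.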
The recommended fixes here are to:
1. Write your custom adjoint for the array-only helper (_fff above), so that the struct serves as syntactic sugar in your program.
2. Take the gradient with respect to s directly. This will give you the Main.s result as the gradient.
Structural gradient for reference:
gradient(fun, S(rand(2, 2)))
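As a small illustration of the structural gradient, here is a sketch with no custom adjoint at all, assuming Zygote is installed and using foo from the first example in place of fun. The result holds one entry per argument, and the entry for s is a NamedTuple that mirrors the fields of S.

```julia
using Zygote

struct S
    W::Array{Float64}
end

foo(s::S) = sum(s.W)

s = S(rand(2, 2))
grads = gradient(foo, s)   # tuple with one gradient per argument of foo
ds = grads[1]              # NamedTuple with the same field names as S
ds.W                       # d(sum(W))/dW: every entry is 1
```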
OK, thanks! For future reference, here is also this example from @ToucheSir (posted on Slack):
using Zygote  # gradient and @adjoint
struct S
W::Array{Float64}
end
s = S(randn(4,4))
fun(s::S) = sum(s.W)
@adjoint function fun(s::S)
fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end
julia> gs = gradient(s) do s
fun(s)
end
((W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],),)
@darsnack One more question. Is there an alternative way I could have written the above explicit adjoint so that this example works fine?
I would go even further and say Zygote should remove the implicit parameter system, or at least Flux should. It seemed like a good idea but 3 years later I think we've all learned it only causes pain. The main gain was syntactic sugar, but the system underlying it never really was that solid. This is just one of many unsolvable issues that arise from it, others being performance or compile time related, along with other weird correctness edge cases. Instead, we should all explore different ways to make explicit parameters have similarly nice syntax, and that would be the best of all worlds.
We have the explicit form, and it would be good to use it. Elsewhere, implicit gradients are tracked through the same rules as explicit ones, so there's no design reason why one should work and the other not.
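For comparison with the implicit Params style, here is a minimal sketch of one explicit-parameter update step, assuming Zygote is installed. sgd_step is a hypothetical helper name, not a Flux or Zygote API. The gradient comes back as a NamedTuple mirroring S, and the updated struct is rebuilt from it, so no Params or global lookup is involved.

```julia
using Zygote

struct S
    W::Array{Float64}
end

foo(s::S) = sum(sin.(s.W))

# Hypothetical helper: differentiate with respect to the struct itself,
# then rebuild it from the NamedTuple gradient (plain gradient descent).
function sgd_step(s::S, η)
    (ds,) = gradient(foo, s)      # ds.W == cos.(s.W) for this foo
    return S(s.W .- η .* ds.W)
end

s = S(randn(4, 4))
s2 = sgd_step(s, 0.1)
```

Returning a new struct keeps the example free of mutation; an in-place variant would update s.W instead, and in Flux this bookkeeping is what an optimiser would normally handle.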
@ChrisRackauckas I would love for nothing more, but there unfortunately hasn't been a big push behind figuring out how to bring explicit params to parity, let alone a migration plan. This includes non-syntactic issues such as how to do tied weights and how to exclude certain params from optimization. I myself have at least a couple pages of design notes on various aspects/challenges, and looking at what others are doing it's clear this is not a trivial task!
Anyhow, I just created a tracking project at https://github.com/orgs/FluxML/projects/2. Please add new issues/tasks as you encounter them—it would be great to record all this disparate discussion about implicit vs explicit params in one place.
I have encountered this issue several times. This is the smallest example I was able to find to reproduce it.
I noticed that gs is storing the correct gradients in W under another key, which equals Main.s, but I'm not even sure what that is and I cannot access it. But the correct key, s.W, is populated with nothing, which is wrong. What is going on here?
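On the Main.s question, here is a minimal sketch of what appears to happen, assuming the custom adjoint is attached to foo itself (the failing snippet is not reproduced above). Zygote.Params([s.W]) stands in for Flux's params(s). Because the rule returns a structural gradient for the whole struct, and s is a non-constant global read inside the closure, Zygote seems to accumulate that gradient under a GlobalRef key for the global variable (printed as Main.s) instead of under s.W. Inspecting it goes through gs.grads, an internal field of Grads, so treat that part as an implementation detail that may differ between Zygote versions.

```julia
using Zygote
using Zygote: @adjoint

struct S
    W::Array{Float64}
end

s = S(randn(4, 4))                # non-constant global, read inside the closure below
ps = Zygote.Params([s.W])         # the same array Flux's params(s) would collect

foo(s::S) = sum(s.W)

# Rule on the whole struct: the pullback returns a NamedTuple gradient for `s`,
# so the AD never sees the array `s.W` being read.
@adjoint foo(s::S) = foo(s), Δ -> ((W = fill(Δ, size(s.W)),),)

gs = gradient(() -> foo(s), ps)

gs[s.W]                           # nothing, as reported in this issue

# Internal detail (assumption): the struct's gradient is stored under the global's reference.
ref = GlobalRef(@__MODULE__, :s)  # displayed as Main.s at the REPL
gs.grads[ref].W                   # the all-ones gradient that never reached s.W
```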