lxvm closed this 10 months ago
Update: since we have adjoints for product and collect, I added an adjoint for collect(product()) intended to work with this example:
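using Zygote, Test
# broadcast first collects the product iterator (via Broadcast.broadcastable),
# so this exercises the collect(product(...)) adjoint; every gradient entry should be 3*4 = 12
@test Zygote.gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2,x))), ones(4)) == (3*4ones(4),)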
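A hand derivation of the expected value in this test: each element of Iterators.product(x .^ 2, x) is a pair (x_i^2, x_j), so prod gives x_i^2 x_j and the sum factorizes. With n = 4 and x = ones(4), each partial derivative is 2·4 + 4 = 12 = 3·4.

```latex
f(x) = \sum_{i=1}^{n}\sum_{j=1}^{n} x_i^2\, x_j
     = \Big(\sum_{i} x_i^2\Big)\Big(\sum_{j} x_j\Big),
\qquad
\frac{\partial f}{\partial x_k} = 2x_k \sum_{j} x_j + \sum_{i} x_i^2 .
```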
Sure, I'll try to get a test written soon, although testing inference can be tricky. Ideally Test.@inferred would work well, but the test also has to catch a regression. I'll try to figure it out tomorrow.
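For reference, a minimal illustration of how Test.@inferred behaves, using plain Julia functions (stable and unstable are hypothetical examples, not code from this PR). An inference test for an adjoint could wrap the Zygote.gradient call the same way, but whether such a call infers concretely depends on the rules involved.

```julia
using Test

# Type-stable: for a Vector{Float64} input the result is always a Vector{Float64}.
stable(x) = x .^ 2

# Type-unstable: the return type (Int or Float64) depends on a runtime value,
# so inference can only conclude Union{Float64, Int64}.
unstable(x) = x[1] > 0 ? 1 : 1.0

# When the inferred and actual return types agree, @inferred just returns the result.
y = @inferred stable([1.0, 2.0])

# When they disagree, @inferred throws an ErrorException, which fails the test.
@test_throws ErrorException @inferred unstable([1.0])
```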
Additionally, I can update the Iterators.zip adjoints to match the product iterator.
I've finished adding inference and correctness tests for product and zip. For zip, some care has to be taken for Numbers, since they are iterable. It looks like the CI errors are unrelated. Does this look good?
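A short Base-only illustration of why Numbers need special care: a Number is an iterable that yields itself exactly once and has a zero-dimensional shape, so inside zip it behaves like a length-1 collection. My reading (an assumption, not stated above) is that a rule treating every zip argument as a collection could therefore return a container where a scalar gradient is expected.

```julia
# A Number is iterable in Julia: it yields itself exactly once.
x = 2.0
a = iterate(x)            # (2.0, nothing)
b = iterate(x, nothing)   # nothing: iteration is finished
println(a, " ", b)

# It also has a zero-dimensional shape, like a 0-d array.
println(size(x), " ", length(x))   # () 1

# Zipped with a vector, the Number acts as a length-1 iterator,
# so zip stops after the first element.
firstpair = first(zip(x, [10, 20, 30]))
npairs = count(_ -> true, zip(x, [10, 20, 30]))
println(firstpair, " ", npairs)    # (2.0, 10) 1
```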
Update: In general, projecting onto the input iterator type is unlikely to work for custom iterators; take for example a range, as seen here:
julia> Zygote.gradient(x -> sum(prod, x), 1:5)
([1.0, 1.0, 1.0, 1.0, 1.0],)
julia> Zygote.gradient(x -> sum(prod, zip(x)), 1:5)
(nothing,)
Any thoughts on how to improve handling this?
Correction: projection is not the reason the example above gives nothing, but I'm not sure what is (I can open a separate issue). Still, shouldn't the adjoints for the iterators be able to handle the projection themselves, as explained here?
Although it does sometimes use ChainRules' projection machinery, Zygote overall doesn't do projection quite the same way, because it predates ChainRules. The legacy projection machinery we do have can be rather inconsistent. In this particular case, however, it looks like the input type is causing the null gradients?
julia> gradient(x -> sum(prod, zip(x)), collect(1:5))
([1.0, 1.0, 1.0, 1.0, 1.0],)
julia> gradient(x -> sum(prod, zip(x)), 1.0:5.0)
([1.0, 1.0, 1.0, 1.0, 1.0],)
It makes sense that an integer range would be considered non-differentiable, but it would be good to confirm Zygote is doing this for the right reason and not because of some bug. Either way, if you can't figure out zip easily, I'd just leave it for a follow-up PR and we can try to get the product changes in first.
Thank you for the context about projection in Zygote. I'm happy to keep it as is so that this PR is as non-breaking as possible.
Otherwise, the work on zip is done, and I added tests equivalent to those for product. I did switch the adjoint for the constructor Iterators.Zip to one for the function Iterators.zip, and I'm not sure whether anything depended specifically on the former, since there were previously no tests for it.
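For context on the constructor-versus-function distinction, a Base-only sketch: in Julia's Base, Iterators.zip(args...) wraps its arguments in an Iterators.Zip, so a rule attached to the constructor sees both entry points, while a rule attached to the function only sees calls that go through zip. Iterators.Zip is an internal, unexported type, so its constructor signature (a single tuple of iterators) is an implementation detail that may differ between Julia versions.

```julia
# Iterators.zip is the same function that Base exports as zip.
println(zip === Iterators.zip)        # true

# Calling the lowercase function produces a value of the uppercase type...
z = Iterators.zip([1, 2], [3, 4])
println(z isa Iterators.Zip)          # true

# ...because zip(args...) wraps its arguments in Zip. Constructing Zip directly
# from a tuple of iterators yields the same elements (internal API, may change).
z2 = Iterators.Zip(([1, 2], [3, 4]))
println(collect(z) == collect(z2))    # true
```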
As for the observation of null gradients for integer ranges, it appears to have nothing to do with zip and everything to do with iteration. Here are some more cases:
julia> Zygote.gradient(x -> sum(prod, Iterators.product(x)), 1:5)
(nothing,)
julia> Zygote.gradient(x -> sum(prod, Iterators.map(identity, x)), 1:5)
(nothing,)
julia> Zygote.gradient(x -> sum(prod, Iterators.take(x,5)), 1:5)
(nothing,)
I'd have to understand where the decision is being made, but I think it's safe to leave it to a follow-up.
It turns out the answer is easier than I thought: https://github.com/FluxML/Zygote.jl/blob/54f1e807d5c098a6100c422b54d14f8cd85e0b3c/src/lib/array.jl#L252
This may have been to support for x in 1:N ... loops.
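Why all of these wrappers lose the gradient in the same way can be seen from Base's iteration protocol alone. The sketch below (manual_sum is a hypothetical helper, not part of Zygote) spells out what a for loop lowers to: one iterate call per element. Iterators.take and Iterators.map (and likewise Iterators.product) simply forward to iterate on the wrapped range, so if, as this discussion suggests, the linked line is a rule that drops the gradient when a UnitRange is iterated, each of these code paths would hit it. This illustrates the protocol only; it is not Zygote's rule.

```julia
# Hypothetical helper: sums an iterable using only the iteration protocol,
# i.e. what `for v in itr; acc += v; end` is lowered to.
function manual_sum(itr)
    acc = 0
    item = iterate(itr)             # first element and initial state, or nothing
    while item !== nothing
        v, state = item
        acc += v
        item = iterate(itr, state)  # advance using the returned state
    end
    return acc
end

r = 1:5
println(manual_sum(r))                           # 15
# Wrapper iterators forward to iterate on the underlying range,
# so they drive exactly the same iterate(::UnitRange, ...) calls.
println(manual_sum(Iterators.take(r, 5)))        # 15
println(manual_sum(Iterators.map(identity, r)))  # 15
```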
Zygote often throws away gradients of a UnitRange, here:
It's not enforced by projection, so things that hit other rules, such as gradient(x -> sum(abs2, x), 1:5), don't give nothing.
I did switch the adjoint for the constructor Iterators.Zip to one for the function Iterators.zip and I'm not sure if something depended especially on the former since there were previously no tests for it.
I have no memory of why, but when initially writing these rules, attaching them to the uppercase constructor rather than the lowercase function somehow made more cases work. There are tests here, but fewer than I thought.
I've rebased this branch onto master and resolved the last issue I was concerned about. It looks like CI is mostly passing, although I'm not sure whether the DynamicPPL failure is related.
This is a great contribution for a tricky set of rules, thanks @lxvm !
Thanks to everyone for the helpful support!
Hi,
I've returned to my first contribution in #1170 since I noticed I couldn't differentiate w.r.t. Iterators.products that have a number as one of their iterators. This PR adds a test and fixes the issue while also improving the inferrability of the adjoint.
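To make the Number-in-product case concrete, a Base-only sketch (the values a and xs and the function f are hypothetical, chosen for illustration). The Number contributes a zero-dimensional axis, so the product's shape comes from the other iterators. Since f(a, xs) = Σⱼ a·xs[j], the reference gradient is ∂f/∂a = sum(xs) = 6 and ∂f/∂xs[j] = a = 2; central finite differences are used here as an independent check that an AD result could be compared against, not as the method used by this PR's test.

```julia
# Hypothetical example: a product in which one "iterator" is a plain Number.
a  = 2.0
xs = [1.0, 2.0, 3.0]
p  = Iterators.product(a, xs)

# The Number contributes a zero-dimensional axis, so the shape comes from xs alone.
println(size(p))      # (3,)
println(collect(p))

# f(a, xs) = Σⱼ a * xs[j], so ∂f/∂a = sum(xs) = 6 and ∂f/∂xs[j] = a = 2.
f(a, xs) = sum(prod, Iterators.product(a, xs))

# Central finite differences give a reference gradient.
h = 1e-6
grad_a = (f(a + h, xs) - f(a - h, xs)) / (2h)
bump(j) = [i == j ? h : 0.0 for i in eachindex(xs)]  # perturb only entry j
grad_xs = [(f(a, xs .+ bump(j)) - f(a, xs .- bump(j))) / (2h) for j in eachindex(xs)]
println(grad_a, " ", grad_xs)
```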