FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Improve adjoint for product and zip #1489

Closed lxvm closed 10 months ago

lxvm commented 10 months ago

Hi,

I've returned to my first contribution in #1170, since I noticed I couldn't differentiate with respect to Iterators.product calls that have a number as one of their iterators. This PR adds a test and fixes the issue, while also improving the inferrability of the adjoint.
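
For illustration, the kind of case this covers is roughly the following (my own sketch, not the exact example that prompted the fix; the expected gradient is noted in the comment):

using Zygote
# One argument of Iterators.product is a plain Number rather than an array.
# Each product element is (x[i], 2.0), so the sum is 2 * sum(x) and the gradient
# with respect to x should be 2.0 for every entry.
Zygote.gradient(x -> sum(broadcast(prod, Iterators.product(x, 2.0))), ones(3))  # expected: ([2.0, 2.0, 2.0],)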

lxvm commented 10 months ago

Update: since we have adjoints for product and collect, I added an adjoint for collect(product()) intended to work with this example:

using Zygote, Test
# Each term is x[i]^2 * x[j]; at x = ones(4) each partial derivative is 2*4 + 4 = 12 = 3*4.
@test Zygote.gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4)) == (3*4ones(4),)

lxvm commented 10 months ago

Sure, I'll try to get a test written soon, although testing inference can be tricky. Ideally Test.@inferred would work, but it also has to be able to pick up a regression; I'll try to figure it out tomorrow.
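
For reference, the kind of check this enables looks roughly like the following (a sketch only, not necessarily the test that ends up in this PR):

using Zygote, Test
# Test.@inferred throws if the call's return type cannot be inferred, so wrapping a
# pullback call in it doubles as an inference regression test. This assumes the
# adjoint infers cleanly after the changes here.
y, back = Zygote.pullback(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4))
@inferred back(1.0)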

Additionally, I can update the Iterators.zip adjoints to match the product iterator.

lxvm commented 10 months ago

I've finished adding inference and correctness tests for product and zip. For zip, some care has to be taken with Numbers, since they are iterable (see the snippet below). It looks like the CI errors are unrelated. Does this look good?
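
A quick illustration of why Numbers need care here (standard Julia behaviour, shown for context):

julia> first(Iterators.zip(ones(2), 3.0))   # a Number iterates as a single element
(1.0, 3.0)

julia> length(Iterators.zip(ones(2), 3.0))  # zip stops at the shortest iterator, here the Number
1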

Update: In general, projecting onto the input iterator type is unlikely to work for custom iterators. Take a range, for example:

julia> Zygote.gradient(x -> sum(prod, x), 1:5)
([1.0, 1.0, 1.0, 1.0, 1.0],)

julia> Zygote.gradient(x -> sum(prod, zip(x)), 1:5)
(nothing,)

Any thoughts on how to improve handling this?

Correction: projection is not the reason the example above gives nothing, but I'm not sure what is (I can open a separate issue). Still, shouldn't the adjoints for the iterators be able to handle the projection themselves, as explained here?

ToucheSir commented 10 months ago

Although it does use ChainRules' projection machinery in places, Zygote overall doesn't do projection quite the same way, because it predates ChainRules, and the legacy projection machinery we do have can be rather inconsistent. In this particular case, however, it looks like the input type is what causes the null gradients:

julia> gradient(x -> sum(prod, zip(x)), collect(1:5))
([1.0, 1.0, 1.0, 1.0, 1.0],)

julia> gradient(x -> sum(prod, zip(x)), 1.0:5.0)
([1.0, 1.0, 1.0, 1.0, 1.0],)

It makes sense that an integer range would be considered non-differentiable, but it would be good to confirm Zygote is doing this for the right reason and not because of some bug. Either way, if you can't figure out zip easily I'd just leave it for a follow-up PR and we can try to get the product changes in first.
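
For reference, the ChainRules-style projection mentioned above works roughly like this (a minimal sketch using ChainRulesCore.ProjectTo, shown for context; Zygote's legacy handling differs, as noted above):

using ChainRulesCore
# A projector built from the primal constrains tangents to a compatible type.
p = ProjectTo([1.0, 2.0])   # projector for a Vector{Float64}
p([1, 2])                   # an Int-valued tangent is projected back to [1.0, 2.0]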

lxvm commented 10 months ago

Thank you for the context about projection in Zygote. I'm happy to keep it as is so that this PR stays as non-breaking as possible.

Otherwise, the work on zip is done, and I've added tests equivalent to those for product. I did switch the adjoint from the constructor Iterators.Zip to the function Iterators.zip, and I'm not sure whether anything depended specifically on the former, since there were previously no tests for it.
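
For context, zip is the user-facing function that builds the Zip wrapper type, which is presumably why the attachment point of the rule can change which cases are covered (a quick check, not part of the PR's tests):

julia> Iterators.zip(ones(2), ones(3)) isa Iterators.Zip
true

An adjoint attached to the lowercase function only fires for calls that go through it, while one attached to the constructor also covers code that builds Iterators.Zip directly.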

As for the observation of null gradients for integer ranges, it appears to have nothing to do with zip and everything to do with iteration. Here are some more cases:

julia> Zygote.gradient(x -> sum(prod, Iterators.product(x)), 1:5)
(nothing,)

julia> Zygote.gradient(x -> sum(prod, Iterators.map(identity, x)), 1:5)
(nothing,)

julia> Zygote.gradient(x -> sum(prod, Iterators.take(x,5)), 1:5)
(nothing,)

I'd have to understand where the decision is being made, but I think it's safe to leave it to a follow-up.

ToucheSir commented 10 months ago

It turns out the answer is easier than I thought:

https://github.com/FluxML/Zygote.jl/blob/54f1e807d5c098a6100c422b54d14f8cd85e0b3c/src/lib/array.jl#L252

This may have been added to support loops like for x in 1:N ....
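
Here is a sketch of the kind of code that reading suggests (my assumption being that the intent was to keep the loop counter out of the gradient):

using Zygote
# Differentiated code that iterates an integer range; treating 1:length(x) as
# non-differentiable means the loop counter never picks up a gradient.
function sumsq(x)
    s = zero(eltype(x))
    for i in 1:length(x)
        s += x[i]^2
    end
    return s
end
Zygote.gradient(sumsq, ones(3))  # expected: ([2.0, 2.0, 2.0],)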

mcabbott commented 10 months ago

Zygote often throws away gradients of a UnitRange, here:

https://github.com/FluxML/Zygote.jl/blob/54f1e807d5c098a6100c422b54d14f8cd85e0b3c/src/lib/array.jl#L252

It's not enforced by projection, so things that hit other rules such as gradient(x -> sum(abs2, x), 1:5) don't give nothing.

I did switch the adjoint for the constructor Iterators.Zip to one for the function Iterators.zip and I'm not sure if something depended especially on the former since there were previously no tests for it.

I have no memory of why, but when initially writing these rules, attaching them to the uppercase constructor rather than the lowercase function somehow made more cases work. There are tests here, but fewer than I thought.

lxvm commented 10 months ago

I've rebased this branch onto master and resolved the last issue I was concerned about. The CI looks mostly good, although I'm not sure whether the DynamicPPL failure is related.

ToucheSir commented 10 months ago

This is a great contribution for a tricky set of rules, thanks @lxvm !

lxvm commented 10 months ago

Thanks to everyone for the helpful support!