FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Always calling `reverse` on `map` breaks stuff #1393

Open simsurace opened 1 year ago

simsurace commented 1 year ago

#1376 broke this, for example (it used to work on v0.6.55):

https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/355

Could the change be reverted (or maybe there is an obvious fix) and a new patch be released?

mcabbott commented 1 year ago

Is there a minimal example?

When calling map with a stateful function, the pre-#1376 behaviour was to blindly apply the pullbacks in the wrong order (except on a Vector), which sometimes gave silently wrong answers. Present behaviour is to error if reverse doesn't work.
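To see why application order matters for a stateful function, here is a minimal sketch with a hypothetical `Counter` callable (not part of Zygote or the thread). Each call's result depends on how many calls came before it, so the output of `map`, and any pullback that captured that state, depends on the order in which elements are visited.

```julia
# Hypothetical stateful callable: the result depends on how many
# times it has already been called.
mutable struct Counter
    n::Int
end

function (c::Counter)(x)
    c.n += 1
    return c.n * x
end

# In the current implementation, map on a Vector visits elements in
# order, so element k sees state k. A reverse pass replaying pullbacks
# for such a function would have to visit them last-to-first.
c = Counter(0)
ys = map(c, [1.0, 1.0, 1.0])
```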

willtebbutt commented 1 year ago

I believe it boils down to

julia> using LinearAlgebra

julia> X = UpperTriangular(randn(3, 3))
3×3 UpperTriangular{Float64, Matrix{Float64}}:
 -0.211524  -1.40637   -0.315358
   ⋅         0.933569  -0.797453
   ⋅          ⋅        -0.339659

julia> reverse(X)
ERROR: ArgumentError: cannot set index in the lower triangular part (3, 2) of an UpperTriangular matrix to a nonzero value (-1.4063656013863817)
Stacktrace:
  [1] setindex!
    @ ~/.julia/juliaup/julia-1.8.5+0.x64.apple.darwin14/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:233 [inlined]
  [2] _setindex!
    @ ./abstractarray.jl:1374 [inlined]
  [3] setindex!
    @ ./abstractarray.jl:1344 [inlined]
  [4] _reverse!(A::UpperTriangular{Float64, Matrix{Float64}}, dims::Tuple{Int64, Int64})
    @ Base ./arraymath.jl:89
  [5] _reverse!
    @ ./arraymath.jl:71 [inlined]
  [6] reverse!(A::UpperTriangular{Float64, Matrix{Float64}}; dims::Function)
    @ Base ./arraymath.jl:70
  [7] _reverse(A::UpperTriangular{Float64, Matrix{Float64}}, dims::Function)
    @ Base ./arraymath.jl:60
  [8] #reverse#247
    @ ./arraymath.jl:59 [inlined]
  [9] reverse(A::UpperTriangular{Float64, Matrix{Float64}})
    @ Base ./arraymath.jl:59
 [10] top-level scope
    @ REPL[3]:1

So if you hit reverse in the reverse-pass of AD and you've mapped over an UpperTriangular or got an UpperTriangular cotangent, you can't reverse.
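A toy model in plain Julia makes the mechanism concrete. It is not Zygote's actual rule: `toy_map_pullback` and its `(y, back)` convention are assumptions for illustration. The forward pass stores one pullback per element; the reverse pass replays them last-to-first, which means reversing the incoming cotangent as well, so the pullback is only as general as `reverse` is on the cotangent's type.

```julia
# Hypothetical toy version of a map rule. `f_and_back(x)` must return
# `(y, back)`, where `back(dy)` is the pullback for that one element.
function toy_map_pullback(f_and_back, xs::AbstractVector)
    results = map(f_and_back, xs)
    ys = map(first, results)
    backs = map(last, results)
    function pb(dys)
        # Replay the pullbacks last-to-first (what a stateful f needs).
        # This is the step that calls `reverse` on the cotangent `dys`.
        dxs_rev = map((back, dy) -> back(dy), reverse(backs), reverse(dys))
        # Undo the reversal so the result lines up with xs again.
        return reverse(dxs_rev)
    end
    return ys, pb
end

# Element-wise doubling: y = 2x, so dx = 2dy.
ys, pb = toy_map_pullback(x -> (2x, dy -> 2dy), [1.0, 2.0, 3.0])
dxs = pb([1.0, 10.0, 100.0])
```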

Present behaviour is to error if reverse doesn't work.

This is breaking in all cases where people were (correctly) using non-mutating functions (using a mutating function with map is a dubious choice in any case, because the semantics of map don't guarantee ordering) and for whatever reason reverse doesn't work on their type. Should this not have prompted at least a minor version bump?

simsurace commented 1 year ago

I'm trying to come up with one that is more minimal, but I can't right now. The call of reverse on UpperTriangular is definitely at play, but I don't know where that UpperTriangular was coming from. By the way, something like

Zygote.gradient(m -> reduce(+, m), UpperTriangular(rand(3, 3)))

produces the same type of error, but may be unrelated, since it doesn't work on 0.6.55 either.

ToucheSir commented 1 year ago

Should this not have prompted at least a minor version bump?

That would've required tests to catch it. Historically the whole map path has been woefully undertested, which would be bad for most rules but especially so for complex, higher-order ones like map. Oh, and the type signature for the rule was way too broad: over-promise, under-deliver and all that. Righting historical wrongs like this proactively requires time and effort that Zygote does not have at its disposal in the present day, so the best we can do is fix them when they come up (like here).

willtebbutt commented 1 year ago

I understand your position on this, but you've got to appreciate that it's infuriating to have patch releases in Zygote breaking my code, when the code I've written is entirely legit (it would be a different matter if I were depending on buggy behaviour).

mcabbott commented 1 year ago

That reverse(UpperTriangular(rand(3,3))) fails is arguably a bug in LinearAlgebra, which would be trivial to solve there, although not trivial to backport. We could add a special path for that here too, but it needs a test otherwise we are just guessing.

A Zygote MWE has to be a gradient which uses this path, and got the correct answer before. I have yet to invent one, e.g. gradient((x,y) -> sum(abs2, map(z -> z/y, x)), (UpperTriangular([1 2; 3 4])), 5)[1] works fine. That's an impure but not stateful function, and I thought Zygote only had paths for pure (no reverse) and impure (reverse). So these statements are not correct:

So if you hit reverse in the reverse-pass of AD and you've mapped over an UpperTriangular

This is breaking in all cases where people were (correctly) using non-mutating functions

Edit: In my example, map seems to be producing a Matrix on the forward pass, hence the gradient it receives will not be an UpperTriangular. ... Edit' like so:

julia> let y = 2 
        pullback(x -> map(z -> z/y, x), UpperTriangular([1 2; 3 4]))[1]
       end
2×2 Matrix{Float64}:
 0.5  1.0
 0.0  2.0

julia> map(x -> x/2, UpperTriangular([1 2; 3 4]))
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 0.5  1.0
  ⋅   2.0

willtebbutt commented 1 year ago

So these statements are not correct:

That's fair. I should have been more specific in my wording: This is breaking in all cases where people were using functions which aren't singleton types and whose application order doesn't matter, and for whatever reason reverse doesn't work on the cotangent passed into the pullback associated with map.
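One way to see the distinction: Julia's internal `Base.issingletontype` tells whether a function's type carries any data. Assuming Zygote's `map` rule skips reversal only for singleton function types (an assumption here, consistent with the wording above rather than documented API), a callable struct or a closure over a local variable takes the reversing path even when its application order is irrelevant. `Scaled` and `halve` are hypothetical names.

```julia
# Hypothetical callable struct in the spirit of the MWE's `Foo`:
# it carries data, so its type is not a singleton, even though the
# order in which it is applied does not matter.
struct Scaled
    a::Float64
end
(s::Scaled)(x) = s.a * x

# A closure over a *local* variable also carries data.
# (A top-level closure referring to a global would not capture it.)
halve = let y = 2.0
    z -> z / y
end

sin_is_singleton = Base.issingletontype(typeof(sin))       # plain function
scaled_is_singleton = Base.issingletontype(Scaled)         # has a field
halve_is_singleton = Base.issingletontype(typeof(halve))   # captures y
```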

Here's a MWE:

using Zygote, LinearAlgebra
struct Foo
    x::Float64
end
(f::Foo)(x) = f.x * x
x = randn(5, 5)
out, pb = Zygote.pullback(x -> map(Foo(5.0), x), x)
pb(UpperTriangular(randn(5, 5)))

Something along the lines of the following will yield an UpperTriangular cotangent:

only(Zygote.gradient(x -> sum(cholesky(Symmetric(x)).U), diagm(ones(5))))

mcabbott commented 1 year ago

Ok. Those could be made into a test.

Can you make an upstream issue about reverse(::UpperTriangular)? Until that's fixed Zygote could have a method something like _tryreverse(::typeof(map), x::UpperTriangular) = LowerTriangular(reverse(parent(x)))
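The suggested method can be checked in isolation. The sketch below uses a hypothetical standalone name, `reverse_upper`, rather than Zygote's internal `_tryreverse`, and compares the result with reversing a dense copy. Reversing both dimensions of an n×n matrix sends entry (i, j) to (n+1-i, n+1-j), so the upper triangle lands in the lower triangle. It assumes Julia 1.6 or later, where `reverse` on a matrix reverses all dimensions by default.

```julia
using LinearAlgebra

# Hypothetical helper (not Zygote API): reverse an UpperTriangular by
# reversing its parent matrix and re-wrapping. The parent's upper
# triangle moves to the lower triangle; LowerTriangular ignores the
# rest of the reversed parent.
reverse_upper(x::UpperTriangular) = LowerTriangular(reverse(parent(x)))

X = UpperTriangular([1 2 3; 4 5 6; 7 8 9])
R = reverse_upper(X)
```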

willtebbutt commented 1 year ago

Can you make an upstream issue about reverse(::UpperTriangular)?

I'm actually not sure what aspect of the behaviour of reverse(::UpperTriangular) is incorrect as I'm not totally familiar with the semantics of reverse. Would you mind making it?

Until that's fixed Zygote could have a method something like _tryreverse(::typeof(map), x::UpperTriangular) = LowerTriangular(reverse(parent(x)))

That would certainly fix the immediate problem.

mcabbott commented 1 year ago

Clearly these should be ==, and arguably the second should be LowerTriangular:

julia> reverse(UpperTriangular([1 2 3; 4 5 6; 7 8 9]) |> collect)
3×3 Matrix{Int64}:
 9  0  0
 6  5  0
 3  2  1

julia> reverse(UpperTriangular([1 2 3; 4 5 6; 7 8 9]))
ERROR: ArgumentError: cannot set index in the lower triangular part (3, 2) of an UpperTriangular matrix to a nonzero value (2)

ToucheSir commented 1 year ago

Another idea is to use applicable. Something like _tryreverse(x) = applicable(reverse, x) ? reverse(x) : reverse(collect(x)). It will not give the most natural cotangent type in many cases, but given both map and Zygote aren't super strict about maintaining input types that's unlikely to be an issue. This also doesn't preclude more specialized methods such as the UpperTriangular overload above (which otherwise would pass the applicable test as a false positive).
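Written out as a hypothetical helper (`reverse_or_collect` and the `Countdown` iterator are illustrative names, not Zygote API), the fallback looks like this. Note that `applicable` only checks that some method of `reverse` matches; `reverse(::AbstractArray)` matches every array wrapper, so `UpperTriangular` passes the test even though the call itself can fail later.

```julia
using LinearAlgebra

# Hypothetical helper (not Zygote API): use `reverse` when some method
# applies, otherwise reverse a collected copy.
reverse_or_collect(x) = applicable(reverse, x) ? reverse(x) : reverse(collect(x))

# A plain iterator with no `reverse` method, so it takes the collect path.
struct Countdown
    n::Int
end
Base.iterate(c::Countdown, i = c.n) = i < 1 ? nothing : (i, i - 1)
Base.length(c::Countdown) = c.n
```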

simsurace commented 1 year ago

Another idea is to use applicable. Something like _tryreverse(x) = applicable(reverse, x) ? reverse(x) : reverse(collect(x)). It will not give the most natural cotangent type in many cases, but given both map and Zygote aren't super strict about maintaining input types that's unlikely to be an issue. This also doesn't preclude more specialized methods such as the UpperTriangular overload above (which otherwise would pass the applicable test as a false positive).

I like this generic fallback because the issue will probably occur with most matrix wrappers.

mcabbott commented 1 year ago

the issue will probably occur with most matrix wrappers

reverse(Diagonal([1,2,3])) is fine, reverse(Symmetric([1 2; 3 4])) is an error, again an upstream bug really. For user types, an error which clearly indicates that your type needs to support reverse might be better than a silent copy / move-to-CPU.
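For `Symmetric`, the same reasoning as in the triangular case suggests a fix: reversing both dimensions of a symmetric matrix gives another symmetric matrix, and the triangle that stores the data flips from upper to lower (or vice versa). The helper `reverse_symmetric` below is hypothetical, not LinearAlgebra or Zygote API; the sketch checks it against reversing a dense copy.

```julia
using LinearAlgebra

# Hypothetical helper (not LinearAlgebra or Zygote API). Reversing both
# dims sends (i, j) to (n+1-i, n+1-j), so the stored triangle of the
# parent flips sides; re-wrap with the opposite `uplo`.
function reverse_symmetric(S::Symmetric)
    flipped = S.uplo == 'U' ? :L : :U
    return Symmetric(reverse(parent(S)), flipped)
end

S = Symmetric([1 2; 3 4])                      # upper triangle: [1 2; 2 4]
RS = reverse_symmetric(S)
S3 = Symmetric([1 2 3; 4 5 6; 7 8 9], :L)      # lower triangle is used
RS3 = reverse_symmetric(S3)
```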

applicable test as a false positive

Notice this -- applicable sees that reverse(::AbstractArray) exists, but the failure is later. So this test would only apply to stranger types, but what exactly I'm not sure. If you get a NamedTuple here (structural gradient, for some iterator perhaps?) then neither path will go well.

ToucheSir commented 1 year ago

That's correct. I was thinking of map on iterators and other exotic types. Part of this is that map itself advertises itself as too general (though at least the fallback works, unlike reverse...). Looking back through old issues I agree with @willtebbutt's argument in https://github.com/FluxML/Zygote.jl/issues/646 that the rule should've been defined constructively for known well-behaving types. In the meantime, maybe we can throw a runtime error approximating ChainRules' NotImplemented for anything which doesn't support reverse?
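A rough sketch of that suggestion, under stated assumptions: the helper name `reverse_or_explain`, the `Opaque` test type and the message are hypothetical, and it throws a plain `ArgumentError` rather than anything from ChainRules. The point is only that a failure inside `reverse` can be turned into a message naming the cotangent type and what it needs to support.

```julia
# Hypothetical helper (not Zygote API): try `reverse`; if it fails with
# a MethodError (no method) or ArgumentError (e.g. a structured matrix
# refusing a write), rethrow an error saying what is needed. Note that
# such errors raised deeper inside `reverse` are wrapped as well.
function reverse_or_explain(dx)
    try
        return reverse(dx)
    catch err
        err isa Union{MethodError, ArgumentError} || rethrow()
        throw(ArgumentError(
            "the pullback of `map` with a non-singleton function needs " *
            "`reverse` to work on a cotangent of type $(typeof(dx)); " *
            "original error: $(sprint(showerror, err))"))
    end
end

struct Opaque end   # hypothetical type with no `reverse` method
```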

simsurace commented 1 year ago

the issue will probably occur with most matrix wrappers

reverse(Diagonal([1,2,3])) is fine, reverse(Symmetric([1 2; 3 4])) is an error, again an upstream bug really. For user types, an error which clearly indicates that your type needs to support reverse might be better than a silent copy / move-to-CPU.

applicable test as a false positive

Notice this -- applicable sees that reverse(::AbstractArray) exists, but the failure is later. So this test would only apply to stranger types, but what exactly I'm not sure. If you get a NamedTuple here (structural gradient, for some iterator perhaps?) then neither path will go well.

Thanks, I did not appreciate these points.

But does this being classified as an upstream bug mean that we have to wait for another Julia minor release? Shouldn't Zygote take into account that reverse does not perfectly fill this role, and use a custom function that can be appropriately overloaded in all known cases for types in Base and the standard library, and only forward to reverse for those cases in which it is known to work?

If we made the rule use _reverse, we could fix the triangular and symmetric cases without type piracy.

mcabbott commented 1 year ago

map itself advertises itself as too general

It is nice that things like this work:

julia> gradient(x -> sum(abs2, @showtype map(Base.splat(/), @showtype zip(x, 1:3))), [4,5,6,7])
typeof(fwd) = Base.Iterators.Zip{Tuple{Vector{Int64}, UnitRange{Int64}}}  # rule for zip does not collect this
typeof(fwd) = Vector{Float64}  # but map(f, ::Zip) just makes an Array. Possibly shared with collect(::Generator)?
([8.0, 2.5, 1.3333333333333333, 0.0],)

julia> gradient((x,y) -> sum(abs2, map(((a,b),) -> a/(b+y), zip(x, 1:3))), [4,5,6,7], 8)
([0.09876543209876543, 0.1, 0.09917355371900825, 0.0], -0.14799041326436488)

Notice that the concern for reverse is the type of the gradient they receive, not the forward argument. Since map on exotic types makes an Array (itself perhaps a mistake) this gradient should again be an array.
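The point about the cotangent's type can be checked without Zygote: `map` over a `zip` collects into a plain `Vector`, so a gradient with respect to that output is array-shaped too, and `reverse` works on it. The sketch only inspects the forward output of the first example above; `zip` stops at the shorter argument, which is why the fourth input received a zero gradient.

```julia
x = [4, 5, 6, 7]
# zip stops at the shorter argument, so only three pairs are used.
out = map(((a, b),) -> a / b, zip(x, 1:3))
```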

could fix the triangular and symmetric cases without type piracy.

Sure. Adding methods to the existing _tryreverse would be one way.