Open simsurace opened 1 year ago
Is there a minimal example?
When calling map
with a stateful function, pre-1376 behaviour was just to blindly apply the pullbacks in the wrong order, except on a vector, which sometimes gave silently wrong answers. Present behaviour is to error if reverse
doesn't work.
I believe it boils down to
julia> using LinearAlgebra
julia> X = UpperTriangular(randn(3, 3))
3×3 UpperTriangular{Float64, Matrix{Float64}}:
-0.211524 -1.40637 -0.315358
⋅ 0.933569 -0.797453
⋅ ⋅ -0.339659
julia> reverse(X)
ERROR: ArgumentError: cannot set index in the lower triangular part (3, 2) of an UpperTriangular matrix to a nonzero value (-1.4063656013863817)
Stacktrace:
[1] setindex!
@ ~/.julia/juliaup/julia-1.8.5+0.x64.apple.darwin14/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:233 [inlined]
[2] _setindex!
@ ./abstractarray.jl:1374 [inlined]
[3] setindex!
@ ./abstractarray.jl:1344 [inlined]
[4] _reverse!(A::UpperTriangular{Float64, Matrix{Float64}}, dims::Tuple{Int64, Int64})
@ Base ./arraymath.jl:89
[5] _reverse!
@ ./arraymath.jl:71 [inlined]
[6] reverse!(A::UpperTriangular{Float64, Matrix{Float64}}; dims::Function)
@ Base ./arraymath.jl:70
[7] _reverse(A::UpperTriangular{Float64, Matrix{Float64}}, dims::Function)
@ Base ./arraymath.jl:60
[8] #reverse#247
@ ./arraymath.jl:59 [inlined]
[9] reverse(A::UpperTriangular{Float64, Matrix{Float64}})
@ Base ./arraymath.jl:59
[10] top-level scope
@ REPL[3]:1
So if you hit reverse
in the reverse-pass of AD and you've mapped over an UpperTriangular
or got an UpperTriangular
cotangent, you can't reverse.
Present behaviour is to error if reverse doesn't work.
This is breaking in all cases where people were (correctly) using non-mutating functions (using a mutating function with map
is a dubious choice in any case, because the semantics of map
don't guarantee ordering) and for whatever reason reverse
doesn't work on their type. Should this not have prompted at least a minor version bump?
I'm trying to come up with one that is more minimal, but I can't right now. The call of reverse
on UpperTriangular
is definitely at play, but I don't know where that UpperTriangular
was coming from.
Btw. something like
Zygote.gradient(m -> reduce(+, m), UpperTriangular(rand(3, 3)))
produces the same type of error, but is maybe unrelated because it doesn't work on 0.6.55 either.
Should this not have prompted at least a minor version bump?
That would've required tests to catch it. Historically the whole map
path has been woefully undertested, which would be bad for most rules but especially so for complex, higher-order ones like map
. Oh, and the type signature for the rule was way too broad: over-promise, under-deliver and all that. Righting historical wrongs like this proactively requires time and effort that Zygote does not have at its disposal in the present day, so the best we can do is fix them when they come up (like here).
I understand your position on this, but you've got to appreciate that it's infuriating to have patch releases in Zygote breaking my code, when the code I've written is entirely legit (it would be a different matter if I were depending on buggy behaviour).
That reverse(UpperTriangular(rand(3,3)))
fails is arguably a bug in LinearAlgebra, which would be trivial to solve there, although not trivial to backport. We could add a special path for that here too, but it needs a test otherwise we are just guessing.
A Zygote MWE has to be a gradient
which uses this path, and got the correct answer before. I have yet to invent one, e.g. gradient((x,y) -> sum(abs2, map(z -> z/y, x)), (UpperTriangular([1 2; 3 4])), 5)[1]
works fine. That's an impure but not stateful function, and I thought Zygote only had paths for pure (no reverse) and impure (reverse). So these statements are not correct:
So if you hit reverse in the reverse-pass of AD and you've mapped over an UpperTriangular
This is breaking in all cases where people were (correctly) using non-mutating functions
Edit: In my example, map
seems to be producing a Matrix on the forward pass, hence the gradient it receives will not be an UpperTriangular. ... Edit' like so:
julia> let y = 2
    pullback(x -> map(z -> z/y, x), UpperTriangular([1 2; 3 4]))[1]
end
2×2 Matrix{Float64}:
0.5 1.0
0.0 2.0
julia> map(x -> x/2, UpperTriangular([1 2; 3 4]))
2×2 UpperTriangular{Float64, Matrix{Float64}}:
0.5 1.0
⋅ 2.0
So these statements are not correct:
That's fair. I should have been more specific in my wording:
This is breaking in all cases where people were using functions which aren't singleton types and whose application order doesn't matter, and for whatever reason reverse doesn't work on the cotangent passed into the pullback associated with map
.
Here's a MWE:
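using Zygote, LinearAlgebra
# A callable struct: not a singleton type, but application order does not matter.
struct Foo
    x::Float64
end
(f::Foo)(x) = f.x * x
x = randn(5, 5)
out, pb = Zygote.pullback(x -> map(Foo(5.0), x), x)
# Passing an UpperTriangular cotangent sends the pullback through reverse, which errors.
pb(UpperTriangular(randn(5, 5)))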
Something along the lines of the following will yield an UpperTriangular
cotangent:
only(Zygote.gradient(x -> sum(cholesky(Symmetric(x)).U), diagm(ones(5))))
Ok. Those could be made into a test.
Can you make an upstream issue about reverse(::UpperTriangular)
? Until that's fixed Zygote could have a method something like _tryreverse(::typeof(map), x::UpperTriangular) = LowerTriangular(reverse(parent(x)))
Can you make an upstream issue about reverse(::UpperTriangular)?
I'm actually not sure what aspect of the behaviour of reverse(::UpperTriangular)
is incorrect as I'm not totally familiar with the semantics of reverse
. Would you mind making it?
Until that's fixed Zygote could have a method something like _tryreverse(::typeof(map), x::UpperTriangular) = LowerTriangular(reverse(parent(x)))
That would certainly fix the immediate problem.
Clearly these should be ==
, and arguably the second should be LowerTriangular
:
julia> reverse(UpperTriangular([1 2 3; 4 5 6; 7 8 9]) |> collect)
3×3 Matrix{Int64}:
9 0 0
6 5 0
3 2 1
julia> reverse(UpperTriangular([1 2 3; 4 5 6; 7 8 9]))
ERROR: ArgumentError: cannot set index in the lower triangular part (3, 2) of an UpperTriangular matrix to a nonzero value (2)
Another idea is to use applicable
. Something like _tryreverse(x) = applicable(reverse, x) ? reverse(x) : reverse(collect(x))
. It will not give the most natural cotangent type in many cases, but given both map
and Zygote aren't super strict about maintaining input types that's unlikely to be an issue. This also doesn't preclude more specialized methods such as the UpperTriangular
overload above (which otherwise would pass the applicable
test as a false positive).
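To make the false-positive point concrete, here is a small check. It assumes the Julia 1.8 behaviour shown in the stack trace at the top of this thread; other Julia versions may behave differently in the second half.

```julia
using LinearAlgebra

X = UpperTriangular([1 2 3; 4 5 6; 7 8 9])

# `applicable` only asks whether some method matches the argument types.
# The generic `reverse(::AbstractArray)` method matches, so this is true.
is_applicable = applicable(reverse, X)

# The call itself can still fail later, when it tries to write a nonzero
# value into the structurally zero part of the wrapper.
call_failed = try
    reverse(X)
    false
catch err
    err isa ArgumentError
end

@show is_applicable call_failed
```

So an `applicable` guard alone would never reach the `collect` branch for `UpperTriangular`; a dedicated method is still needed for that case.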
I like this generic fallback because the issue will probably occur with most matrix wrappers.
the issue will probably occur with most matrix wrappers
reverse(Diagonal([1,2,3]))
is fine, reverse(Symmetric([1 2; 3 4]))
is an error, again an upstream bug really. For user types, an error which clearly indicates that your type needs to support reverse
might be better than a silent copy / move-to-CPU.
applicable
test as a false positive
Notice this -- applicable
sees that reverse(::AbstractArray)
exists, but the failure is later. So this test would only apply to stranger types, but what exactly I'm not sure. If you get a NamedTuple here (structural gradient, for some iterator perhaps?) then neither path will go well.
That's correct. I was thinking of map
on iterators and other exotic types. Part of this is that map
itself advertises itself as too general (though at least the fallback works, unlike reverse
...). Looking back through old issues I agree with @willtebbutt's argument in https://github.com/FluxML/Zygote.jl/issues/646 that the rule should've been defined constructively for known well-behaving types. In the meantime, maybe we can throw a runtime error approximating ChainRules' NotImplemented
for anything which doesn't support reverse
?
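A rough sketch of what such a runtime error could look like, assuming nothing about Zygote's internals. `reverse_or_explain` is a hypothetical name used only for illustration, and wrapping the call in `try`/`catch` is just one possible design, not something proposed verbatim above.

```julia
# Hypothetical helper: try `reverse`, and if it fails, replace the low-level
# error with one that names the cotangent type and the requirement.
function reverse_or_explain(x)
    try
        return reverse(x)
    catch err
        # Let user interrupts through untouched.
        err isa InterruptException && rethrow()
        throw(ArgumentError(
            "the pullback of `map` needs `reverse` to work on a cotangent of type " *
            "$(typeof(x)); the call failed with: " * sprint(showerror, err)))
    end
end

reverse_or_explain([1, 2, 3])   # plain arrays pass straight through
```

The trade-off is that every failure inside `reverse` gets reported the same way, which is why the original error text is kept in the message.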
the issue will probably occur with most matrix wrappers
reverse(Diagonal([1,2,3]))
is fine, reverse(Symmetric([1 2; 3 4]))
is an error, again an upstream bug really. For user types, an error which clearly indicates that your type needs to support reverse
might be better than a silent copy / move-to-CPU.
applicable
test as a false positive
Notice this -- applicable
sees that reverse(::AbstractArray)
exists, but the failure is later. So this test would only apply to stranger types, but what exactly I'm not sure. If you get a NamedTuple here (structural gradient, for some iterator perhaps?) then neither path will go well.
Thanks, I did not appreciate these points.
But does this being classified as an upstream bug mean that we have to wait for another Julia minor release? Shouldn't Zygote take into account that reverse
does not perfectly fill this role, and use a custom function that can be appropriately overloaded in all known cases for types in Base and the standard library, and only forward to reverse for those cases in which it is known to work?
If we made the rule use _reverse
, we could fix the triangular and symmetric cases without type piracy.
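A minimal sketch of that idea, assuming a package-owned function named `_reverse` as suggested above. The methods below are illustrative guesses, not Zygote's actual code; because the package owns `_reverse`, adding methods for `LinearAlgebra` types is not piracy.

```julia
using LinearAlgebra

# Default: defer to Base for everything that already works.
_reverse(x) = reverse(x)

# Reversing every dimension maps the upper triangle onto the lower one
# and vice versa, so the wrapper type has to flip.
_reverse(x::UpperTriangular) = LowerTriangular(reverse(parent(x)))
_reverse(x::LowerTriangular) = UpperTriangular(reverse(parent(x)))

# Reversing both dimensions of a symmetric matrix gives another symmetric
# matrix, so reverse a dense copy and re-wrap it.
_reverse(x::Symmetric) = Symmetric(reverse(collect(x)))

U = UpperTriangular([1 2 3; 4 5 6; 7 8 9])
S = Symmetric([1 2; 3 4])

_reverse(U) == reverse(collect(U))   # agrees with the dense result
_reverse(S) == reverse(collect(S))
```

Whether the triangular methods should return the flipped wrapper or a plain `Matrix` is a design choice; the flipped wrapper avoids a dense copy but changes the cotangent type.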
map itself advertises itself as too general
It is nice that things like this work:
julia> gradient(x -> sum(abs2, @showtype map(Base.splat(/), @showtype zip(x, 1:3))), [4,5,6,7])
typeof(fwd) = Base.Iterators.Zip{Tuple{Vector{Int64}, UnitRange{Int64}}} # rule for zip does not collect this
typeof(fwd) = Vector{Float64} # but map(f, ::Zip) just makes an Array. Possibly shared with collect(::Generator)?
([8.0, 2.5, 1.3333333333333333, 0.0],)
julia> gradient((x,y) -> sum(abs2, map(((a,b),) -> a/(b+y), zip(x, 1:3))), [4,5,6,7], 8)
([0.09876543209876543, 0.1, 0.09917355371900825, 0.0], -0.14799041326436488)
Notice that the concern for reverse
is the type of the gradient they receive, not the forward argument. Since map
on exotic types makes an Array (itself perhaps a mistake) this gradient should again be an array.
could fix the triangular and symmetric cases without type piracy.
Sure. Adding methods to the existing _tryreverse
would be one way.
1376 broke this for example (this used to work on v0.6.55):
https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/355
Could the change be reverted (or maybe there is an obvious fix) and a new patch be released?