JuliaTeachingCTU / Scientific-Programming-in-Julia

Repository for B0M36SPJ
https://juliateachingctu.github.io/Scientific-Programming-in-Julia/dev/
MIT License

Adding ChainRules and nitpicks around it #115

Status: Open. pevnak opened this issue 1 year ago.

pevnak commented 1 year ago

Copied from Zulip

Maarten: What, exactly, is ProjectTo supposed to do (ChainRules, Zygote)? According to the documentation:

project_A = ProjectTo(A) can be used (outside the pullback) to extract an object that knows how to project onto the type of A

And yet the Zygote source implements:

    (project::ProjectTo{AbstractArray})(dx::Tangent) = dx

ChainRules, too, makes absolutely no attempt at reconstructing any subtype of AbstractArray. ProjectTo(A) therefore cannot be used to project onto something of type A. In fact, the output type of ProjectTo can be almost anything, from Tangent to arrays.

Michael Abbott: So it's a bit complicated, and there are some fallback paths like the method you point at. The most obvious thing it does is ProjectTo(1)(2+3im), disallowing a real number from having a complex gradient.
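A minimal sketch of the number case, assuming ChainRulesCore v1.x is installed. As far as I understand the current methods, a projector built from a real primal discards the imaginary part, a Float32 primal keeps its gradient in Float32, and a complex primal promotes a real gradient:

```julia
using ChainRulesCore

p_real = ProjectTo(1)             # integer primal: the projector targets Float64
p_real(2 + 3im)                   # imaginary part discarded, gives 2.0

p_f32 = ProjectTo(1.0f0)          # Float32 primal
p_f32(2.0)                        # converted back to Float32, gives 2.0f0

p_cplx = ProjectTo(1.0 + 0.0im)   # complex primal
p_cplx(2.0)                       # real gradient promoted, gives 2.0 + 0.0im
```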

Maarten: Is there a mechanism to convert a tangent type to something like the primal type? I use Zygote, which eagerly loses the type of its tangents, so the tangent types I get are Tangent{Any,...}. Calling canonicalize doesn't work, and ProjectTo doesn't seem to do this either.

Is my only option to instantiate something like the primal type and iterate over the fields that are defined in the Tangent{Any,...} structure, copying them over?

Maarten: I am also fine with documentation being complicated, but I just can't seem to find what ProjectTo exactly should be doing. I'm now trying to define my own projection rules to get around the default fallbacks.

Michael Abbott: So it is supposed to convert back from structural to natural where it can. But Zygote forgets things...

    julia> ProjectTo(2+3im)(Tangent{Complex}(re=4))
    4.0 + 0.0im

    julia> ProjectTo(2+3im)(Tangent{Any}(re=4))  # converted from Zygote's NamedTuple; don't recall why this doesn't work
    ERROR: MethodError: no method matching

    julia> p = ProjectTo(Diagonal(Float32[1,2,3]));

    julia> p(rand(3,3))  # just like real numbers really: projects onto the subspace
    3×3 Diagonal{Float32, Vector{Float32}}:
     0.485392   ⋅         ⋅
      ⋅        0.996753   ⋅
      ⋅         ⋅        0.410348

    julia> p(Tangent{Diagonal}(diag = [4,5,6]))  # for this there's a stalled PR which converts it back
    Tangent{Diagonal}(diag = [4, 5, 6],)

Michael Abbott: Sorry, the documentation is not great; I thought we had more examples in there. There's supposed to be a paper being written...

Maarten: But doesn't it make more sense to just error whenever it fails to project? Then I could define the necessary rules and have everything running just fine. As it is, I've been digging through the ProjectTo code, providing more specialized cases, in the hope of intercepting these fallbacks.

Michael Abbott:

instantiate something like the primal type and iterate over the fields that are defined in the Tangent{Any,...} structure, copying them over?

You can do this. If you know it to be correct, you can define a method of ProjectTo which will do it everywhere for you. But it isn't done in general as it isn't always correct.
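A minimal sketch of such an opt-in, for a hypothetical struct `Point` where copying fields back really is correct (each field enters linearly, unlike LinRange). It assumes ChainRulesCore v1.x; the helper names `_field` and `_to_point` are illustrative only. It also accepts `Tangent{Any}`, which is what the Zygote bridge tends to produce:

```julia
using ChainRulesCore

# Hypothetical example type: a plain 2-vector whose natural tangent is itself a Point.
struct Point
    x::Float64
    y::Float64
end

# Opt in: every Point primal gets a projector of type ProjectTo{Point}.
ChainRulesCore.ProjectTo(::Point) = ProjectTo{Point}()

# Fields missing from a Tangent read back as ZeroTangent(); treat them as 0.0.
_field(z) = z isa ChainRulesCore.AbstractZero ? 0.0 : Float64(z)
_to_point(dp::Tangent) = Point(_field(dp.x), _field(dp.y))

# Two narrow methods rather than one `dp::Tangent` method: a catch-all may be
# ambiguous with a generic pass-through method for Tangent{<:T}.
(::ProjectTo{Point})(dp::Tangent{<:Point}) = _to_point(dp)
(::ProjectTo{Point})(dp::Tangent{Any}) = _to_point(dp)   # primal type forgotten

p = ProjectTo(Point(1.0, 2.0))
p(Tangent{Point}(; y = 3.0))   # Point(0.0, 3.0)
p(Tangent{Any}(; x = 1.5))     # Point(1.5, 0.0)
```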

Michael Abbott: I think we tried stricter versions and it produced so many errors that we couldn't inflict this on everyone. Gave up and made it fix things it knows how to fix.

Michael Abbott: BTW the easy example where re-constructing the same struct from the tangent's fields goes badly wrong is:

    julia> LinRange(1,2,3)
    3-element LinRange{Float64, Int64}:
     1.0, 1.5, 2.0

    julia> dump(ans)
    LinRange{Float64, Int64}
      start: Float64 1.0
      stop: Float64 2.0
      len: Int64 3
      lendiv: Int64 2

    julia> gradient(x -> x.stop, LinRange(1,2,3))
    (Tangent{LinRange{Float64, Int64}}(stop = 1.0,),)

Maarten: It indeed makes a lot of sense that the tangent should not be of the same type as the input, and in those cases I would then expect ProjectTo(::LinRange)(::Tangent{LinRange}) to fail, as it cannot be projected to something of type LinRange.

Or is ProjectTo meant to project onto a "canonical" tangent type for a given primal? As in: the backward rule can receive many tangent types, but projecting them at least standardizes them to one canonical tangent type?

Michael Abbott: That one can in fact make another LinRange, but not the obvious one. In other cases there is no way.
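A small numerical check of the LinRange case, using only Base and LinearAlgebra (my own illustration, not from the thread). Treat the natural gradient of `x -> x.stop` as the dense vector `[0, 0, 1]`, compare it with the naive LinRange built by copying `stop = 1.0`, then project it onto the 2-dimensional space of length-3 LinRanges. Reading "another LinRange, but not the obvious one" as this orthogonal projection is an assumption:

```julia
using LinearAlgebra

x = LinRange(1.0, 2.0, 3)     # elements 1.0, 1.5, 2.0
g = [0.0, 0.0, 1.0]           # natural gradient of x -> x.stop, i.e. of x[end]

# Naive reconstruction: copy `stop = 1.0` from the Tangent, leave start at zero.
naive = LinRange(0.0, 1.0, 3) # elements 0.0, 0.5, 1.0

# Perturbation shifting start and stop together by 1; x.stop changes by exactly 1.
δ = ones(3)
dot(g, δ)       # 1.0, the correct directional derivative
dot(naive, δ)   # 1.5, so the naive LinRange is not a valid gradient

# Orthogonal projection of g onto length-3 arithmetic progressions,
# spanned by the columns [1, 1, 1] and [0, 1, 2]:
B = [ones(3) 0:2]
proj = B * (B \ g)                            # ≈ [-1/6, 1/3, 5/6]
projected = LinRange(proj[1], proj[end], 3)   # a different LinRange
dot(projected, δ)                             # ≈ 1.0, agrees with g along δ
```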

Michael Abbott: I think the idea is that ProjectTo{T} ought to, for each T, decide whether it wants to make the gradient always structural (a Tangent) or always natural (e.g. an AbstractArray). Then enforce this choice.

Michael Abbott: It's not quite as narrow as one type. It allows ZeroTangent for anything. And more generally, for any "natural" tangent it allows more specific types: e.g. it allows Fill as the gradient for a Matrix, and it allows Dual as the gradient of a Float64.
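A sketch of that behaviour for a dense Matrix primal, assuming ChainRulesCore v1.x and LinearAlgebra. Diagonal stands in for FillArrays' Fill here to avoid an extra dependency; I am assuming both take the same "more specific array passes through" path:

```julia
using ChainRulesCore, LinearAlgebra

p = ProjectTo(rand(3, 3))          # primal is a dense Matrix{Float64}

p(ZeroTangent())                   # zeros are accepted for any primal
p(Diagonal([1.0, 2.0, 3.0]))       # a more specific array comes back still a Diagonal
p(fill(1.0 + 2.0im, 3, 3))         # complex entries projected to real: a Matrix of 1.0
```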

Maarten: I think ProjectTo currently can be used to change such a FillArray to a dense Matrix; is this a mistake? I find it very hard to write consistent rules when the backward types can be almost anything.

Maarten: And the ProjectTo problem with uninitialized arrays: would you solve this at the level of ChainRules, using something like

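    function ProjectTo(xs::AbstractArray)
        # Unassigned (undef) slots get a zero projector instead of reading xs[ind],
        # which would throw UndefRefError. ProjectTo{NoTangent}() rather than NoTangent()
        # keeps every entry callable and lets the all-zero check below match.
        elements = map(eachindex(xs)) do ind
            isassigned(xs, ind) ? ProjectTo(xs[ind]) : ProjectTo{NoTangent}()
        end
        if elements isa AbstractArray{<:ProjectTo{<:AbstractZero}}
            return ProjectTo{NoTangent}()  # short-circuit if all elements project to zero
        else
            # Arrays of arrays come here, and will apply projectors individually:
            return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs))
        end
    end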

Or should ProjectTo never have to be called on an uninitialized array anyway?
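A standalone sketch of the per-element idea, assuming ChainRulesCore v1.x; `element_projectors` is a hypothetical helper, not a library function. It substitutes `ProjectTo{NoTangent}()` (a callable projector) for unassigned slots, so applying it to any gradient for that slot yields `NoTangent()`:

```julia
using ChainRulesCore

# An array with an unassigned slot: reading xs[2] would throw UndefRefError.
xs = Vector{Vector{Float64}}(undef, 2)
xs[1] = [1.0, 2.0]

# Hypothetical helper: one projector per slot, zero projector where undef.
function element_projectors(xs)
    map(eachindex(xs)) do i
        isassigned(xs, i) ? ProjectTo(xs[i]) : ProjectTo{NoTangent}()
    end
end

ps = element_projectors(xs)
ps[1]([3.0, 4.0])   # ordinary projection for the assigned element
ps[2]([3.0, 4.0])   # NoTangent() for the undef slot
```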

Michael Abbott: I'm not sure FillArrays are ever a good idea, honestly. There was some thought that you should allow, say, Diagonal as the gradient for UpperTriangular, or Matrix. Mathematically these both seem OK, but in practice I am not sure the resulting uncertainty is worth the supposed efficiency. For sure it complicates accumulation of gradients.

Michael Abbott: The design would I think let you forbid these, or convert them back, but the present set of methods do not.

Michael Abbott: Re arrays of things other than Numbers, and also Tuple / NamedTuple, I'm not so sure we should have made this recursive, rather than relying on projection inside whatever rules produce the constituents to get them right. But that's where we ended up in the push to CR 1.0.

Michael Abbott: I'm not sure what I think of this unassigned business. It still seems really weird to me to ask what the gradient is with respect to an undef quantity; I'm not sure why zero is a better answer than an error. But sometimes AD goes places where it's not really wanted, and the answer is sure to be discarded...

Maarten: In what context is it desirable to have a structural tangent type that doesn't know its primal type (so Tangent{Any,...})?

Michael Abbott: Never. But the bridge between Zygote & CR doesn't always know the type. Zygote just uses NamedTuples internally.

Michael Abbott: (Which might be better, simpler, but that's another topic)
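A minimal sketch of re-attaching the primal type by hand, assuming ChainRulesCore v1.x; `retag` is a hypothetical helper. In the REPL session earlier in the thread, `Tangent{Any}(re=4)` hit a MethodError while `Tangent{Complex}(re=4)` projected fine, so tagging Zygote's NamedTuple with the primal type lets the `Tangent{<:Complex}` method apply:

```julia
using ChainRulesCore

# Hypothetical helper: turn a NamedTuple gradient (as Zygote uses internally)
# into a Tangent that carries the primal's type.
retag(x::T, nt::NamedTuple) where {T} = Tangent{T}(; nt...)

x = 2.0 + 3.0im
dx_zygote = (re = 4.0,)            # structural gradient, primal type forgotten

ProjectTo(x)(retag(x, dx_zygote))  # 4.0 + 0.0im
```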

Maarten: A combination of the undef fix and the Zygote broadcast fix now allows me to differentiate through my code :)

I still wish for a canonical "get_natural_tangent", which errors when it can't give me one, so that I can gradually opt in to different gradient types that may be faster.
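One way such a strict wrapper could look, assuming ChainRulesCore v1.x; `natural_tangent` and the type `Opaque` are hypothetical. It simply projects and refuses to hand back a structural Tangent. For a type with no ProjectTo method of its own, ProjectTo falls back to `identity`, so the Tangent survives and the wrapper errors:

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical strict wrapper: project, and refuse to return a structural Tangent.
function natural_tangent(x, dx)
    dy = ProjectTo(x)(dx)
    dy isa Tangent && throw(ArgumentError("no natural tangent for $(typeof(x)) from $(typeof(dx))"))
    return dy
end

natural_tangent(Diagonal([1.0, 2.0]), [1.0 2.0; 3.0 4.0])   # Diagonal([1.0, 4.0])

struct Opaque   # hypothetical type with no ProjectTo method of its own
    v::Float64
end

err = try
    natural_tangent(Opaque(1.0), Tangent{Opaque}(; v = 2.0))
catch e
    e
end
err isa ArgumentError   # true: the projector was identity, so a Tangent came back
```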

Maarten: Wait, actually: isn't ProjectTo(A) simply the projection onto the tangent space of A?

Michael Abbott: Yes. Maybe I'd say "into", to stress that the map accepts elements of a smaller space. With the big caveat that both Tangent and Array may be acceptable representations of the tangent space. I would like it to standardise completely (for any given T, either always Union{Tangent,Zero} or always Union{AbstractArray,Zero}), but bolting that onto existing code seemed too hard. https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446 is my (stalled) attempt to at least convert most of the weird LinearAlgebra matrices from Tangent back to AbstractMatrix.
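One way to write the "projection onto the tangent space" reading precisely, as my own formalization (assuming the real inner product on numbers and the Frobenius inner product on matrices; the thread does not spell this out). It reproduces the real-number and Diagonal behaviour shown in the REPL session, though ProjectTo itself also passes some representations through unchanged:

```latex
\begin{aligned}
&P_A(G) \in T_A, \qquad \langle G - P_A(G),\, \delta \rangle = 0 \quad \text{for all } \delta \in T_A, \\
&x \in \mathbb{R},\ \langle a, b \rangle = \operatorname{Re}(\bar{a}\, b): \qquad P_x(g) = \operatorname{Re} g, \\
&A = \operatorname{Diagonal}(d): \qquad P_A(G) = \operatorname{Diagonal}(G_{11}, \dots, G_{nn}).
\end{aligned}
```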

Maarten: I guess a way of converting between different tangent types would also solve this, and allow for specialization?

Michael Abbott: Yes. I think I worked out the mathematics of this at some point, and decided you need three functions, uncollect (makes Tangent), restrict (takes one array to another) & naturalise (Tangent to an AbstractArray). My experiments were at https://github.com/mcabbott/OddArrays.jl but we didn't finish writing the paper.
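A sketch of what the three maps might look like for Diagonal only, assuming ChainRulesCore v1.x and LinearAlgebra. The names follow the message above, but the signatures are my guess and not the API of OddArrays.jl. Diagonal is the easy case because its `diag` field enters the matrix linearly with coefficient 1; for LinRange or Symmetric the maps are not plain field copies:

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical sketches of the three maps, Diagonal case only.

# natural -> structural: a Diagonal gradient as a Tangent over its `diag` field
uncollect(dD::Diagonal) = Tangent{typeof(dD)}(; diag = dD.diag)

# array -> array: restrict a dense gradient to the diagonal subspace
restrict(::Type{<:Diagonal}, dM::AbstractMatrix) = Diagonal(diag(dM))

# structural -> natural: rebuild the Diagonal from the Tangent's `diag` field
naturalise(t::Tangent{<:Diagonal}) = Diagonal(t.diag)

dM = [1.0 2.0; 3.0 4.0]
dD = restrict(Diagonal, dM)   # Diagonal([1.0, 4.0])
t = uncollect(dD)             # Tangent over the diag field, holding [1.0, 4.0]
naturalise(t) == dD           # true: the round trip is lossless for Diagonal
```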