JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

`rrules` do not support chunked mode #566

Open oscardssmith opened 2 years ago

oscardssmith commented 2 years ago

This is a general issue, but here is one specific instance: https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40

function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return Base.vect(X...), vect_pullback
end

This rule implicitly assumes that `ȳ` is a `Vector`, but if you are taking a Jacobian, `ȳ` will be a `Matrix`, in which case the rule should be

function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), ȳ...)
    return Base.vect(X...), vect_pullback
end

Similar problems exist for the `getindex` rrules, and there are probably many other such cases. Is there a good general solution to this?
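To make the shape mismatch concrete, here is a dependency-free sketch in plain Julia. It assumes (ChainRulesCore defines no such layout) that a chunk of k cotangent seeds for the N-element output of `Base.vect` arrives as an N×k matrix with one column per seed. Under that assumption the cotangent for input `X[i]` is row `i`, whereas splatting the matrix iterates all N·k entries in column-major order. The name `chunked_vect_pullback` is hypothetical, and the `NoTangent()` slot for the function itself is omitted so the example runs without ChainRulesCore.

```julia
# Hypothetical chunk-aware pullback for `Base.vect`; not part of ChainRules.jl.
# Assumption: k cotangent seeds for an N-element output arrive as an N×k matrix,
# one column per seed. The NoTangent() slot for the function is omitted here.
chunked_vect_pullback(ȳ::AbstractVector) = Tuple(ȳ)  # one scalar cotangent per input
chunked_vect_pullback(ȳ::AbstractMatrix) = ntuple(i -> ȳ[i, :], size(ȳ, 1))  # row i: k cotangents for X[i]

ȳ_vec = [1.0, 2.0, 3.0]                 # ordinary cotangent, same size as the primal
ȳ_chunk = [1.0 0.0; 0.0 1.0; 2.0 3.0]   # N = 3 outputs, k = 2 seeds

chunked_vect_pullback(ȳ_vec)      # (1.0, 2.0, 3.0)
chunked_vect_pullback(ȳ_chunk)    # ([1.0, 0.0], [0.0, 1.0], [2.0, 3.0])

# Splatting the matrix iterates its N*k entries in column-major order instead:
(ȳ_chunk...,)                     # (1.0, 0.0, 2.0, 0.0, 1.0, 3.0)
```

With the matrix layout, both `NTuple{N}(ȳ)...` and `ȳ...` produce the wrong number or shape of cotangents, which is why a general answer needs an agreed convention for how chunks are represented.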

mcabbott commented 2 years ago

I think you're asking whether there's a scheme for chunked reverse mode. There is not: at present, (co)tangents must match the size of the primal. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 has some discussion; see also https://github.com/JuliaDiff/Diffractor.jl/pull/54.
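Because cotangents must match the primal's size, a reverse-mode Jacobian currently has to be assembled one row at a time: one pullback call per output, each seeded with a unit cotangent shaped like the output. The sketch below does this for a hand-written pullback of a toy function; `f`, `f_pullback`, and `jacobian_by_rows` are illustrative names, not ChainRules or Diffractor API. A chunked mode would instead push several seeds through a single pullback call.

```julia
# Toy function and a hand-written pullback (illustrative only, not a ChainRules rule).
f(x) = [x[1] * x[2], sin(x[1])]

function f_pullback(x)
    # ȳ must have the same size as f(x): a length-2 vector.
    # Returns J' * ȳ, where J = [x[2] x[1]; cos(x[1]) 0].
    return ȳ -> [ȳ[1] * x[2] + ȳ[2] * cos(x[1]), ȳ[1] * x[1]]
end

# Without a chunked mode, build the Jacobian one row per pullback call,
# seeding the pullback with each unit cotangent e_i.
function jacobian_by_rows(f, pullback, x)
    y = f(x)
    back = pullback(x)
    J = zeros(length(y), length(x))
    for i in eachindex(y)
        e = zeros(length(y))
        e[i] = 1.0
        J[i, :] = back(e)   # row i of J is the gradient of y[i]
    end
    return J
end

x = [2.0, 3.0]
J = jacobian_by_rows(f, f_pullback, x)   # equals [3.0 2.0; cos(2.0) 0.0]
```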

Edit: most rules enforce this via projection with `ProjectTo`:

julia> x = [1,2,3];  # vector primal

julia> ProjectTo(x)([4;5;6;;])  # allows 1-column matrix, converts to vector
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> ProjectTo(x)([4 5 6])  # does not allow worse shapes
ERROR: DimensionMismatch: variable with size(x) == (3,) cannot have a gradient with size(dx) == (1, 3)

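The size behavior in the REPL session above can be approximated in a few lines of plain Julia. `project_shape` is a hypothetical simplification written for illustration, not ChainRulesCore's `ProjectTo`: it only reproduces the size check seen above (leading dimensions must match, extra trailing dimensions must have length 1) and does no element-type conversion, so it returns integers where the session above shows `Float64`. It also shows why an N×k chunk of cotangents is rejected.

```julia
# Hypothetical, simplified stand-in for the size check shown above.
# NOT ChainRulesCore's ProjectTo: no element-type projection, only shapes.
function project_shape(x::AbstractArray, dx::AbstractArray)
    dims = 1:max(ndims(x), ndims(dx))
    # size(A, d) returns 1 for d > ndims(A), so extra trailing singleton dims are accepted.
    if all(size(dx, d) == size(x, d) for d in dims)
        return reshape(dx, size(x))
    end
    throw(DimensionMismatch("variable with size(x) == $(size(x)) cannot have a gradient with size(dx) == $(size(dx))"))
end

x = [1, 2, 3]
project_shape(x, reshape([4, 5, 6], 3, 1))   # 3×1 matrix becomes the vector [4, 5, 6]
# project_shape(x, [4 5 6])                 # throws DimensionMismatch: (1, 3) vs (3,)
# project_shape(x, [4 7; 5 8; 6 9])         # a 3×2 chunk (k = 2 seeds) is rejected too
```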
mcabbott commented 2 years ago

For now, the current status should be clearly documented, perhaps on these pages:

https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html

https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html