oscardssmith opened this issue 2 years ago
I think you're asking whether there's a scheme for chunked reverse mode. There is not: at present, (co)tangents have the same size as the primal. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 has some discussion; see also https://github.com/JuliaDiff/Diffractor.jl/pull/54.
Edit: most rules enforce this via projection with `ProjectTo`:
```julia
julia> x = [1,2,3]; # vector primal

julia> ProjectTo(x)([4;5;6;;]) # allows 1-column matrix, converts to vector
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> ProjectTo(x)([4 5 6]) # does not allow worse shapes
ERROR: DimensionMismatch: variable with size(x) == (3,) cannot have a gradient with size(dx) == (1, 3)
```
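The shape rule in the REPL session above can be imitated in a few lines of plain Julia. This is a minimal sketch that assumes only the vector case shown there; `project_like` is a hypothetical name, and the real `ProjectTo` handles many more cases.

```julia
# Hypothetical stand-in for the shape check shown above; NOT ChainRulesCore's ProjectTo.
function project_like(x::AbstractVector, dx::AbstractArray)
    if dx isa AbstractVector && length(dx) == length(x)
        return float.(dx)         # same shape as the primal: accept, convert to floats
    elseif dx isa AbstractMatrix && size(dx) == (length(x), 1)
        return float.(vec(dx))    # one-column matrix: same data, drop the trailing dimension
    else
        throw(DimensionMismatch(
            "variable with size(x) == $(size(x)) cannot have a gradient with size(dx) == $(size(dx))"))
    end
end
```

Under this rule, a chunk of several cotangents stored as the columns of a 3×k matrix is rejected for any k > 1, which illustrates why projection and chunked reverse mode do not fit together under the current convention.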
For now, the current status should be clearly documented, perhaps on these pages:

- https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html
- https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html
This is a general issue, but for one specific incarnation, see https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40

This rule implicitly assumes that `ȳ` is a `Vector`, but if you are taking a Jacobian, it will be a `Matrix`, in which case it should be …

Similar problems also exist for the `getindex` rrules, and I'm sure there are many other similar cases. Is there a good general solution to this?
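To make the vector assumption concrete, here is a minimal Base-Julia sketch built around a hypothetical `getindex`-style pullback. The names `getindex_pullback`, `jacobian_by_rows`, and `getindex_pullback_chunked` are illustrative only, not ChainRules.jl functions. When each call receives one cotangent shaped like the output, a Jacobian needs one pullback per output row; a chunked pullback would instead accept a matrix whose columns are separate cotangents.

```julia
using LinearAlgebra  # for the identity `I` used in the usage example

# Hypothetical pullback for y = x[inds] that assumes ȳ is a single cotangent shaped like y.
function getindex_pullback(x::AbstractVector, inds::AbstractVector{<:Integer}, ȳ::AbstractVector)
    x̄ = zeros(eltype(ȳ), length(x))
    for (k, i) in enumerate(inds)
        x̄[i] += ȳ[k]  # accumulate, so repeated indices are handled correctly
    end
    return x̄
end

# With one cotangent per call, the Jacobian is assembled one row (one pullback) at a time.
function jacobian_by_rows(pullback, m::Integer, n::Integer)
    J = zeros(m, n)
    for k in 1:m
        e = zeros(m)
        e[k] = 1.0            # k-th basis cotangent
        J[k, :] = pullback(e)
    end
    return J
end

# Hypothetical chunked variant: each column of Ȳ is a separate cotangent, so a
# single call returns one column of the input cotangent per output cotangent.
function getindex_pullback_chunked(x::AbstractVector, inds::AbstractVector{<:Integer}, Ȳ::AbstractMatrix)
    X̄ = zeros(eltype(Ȳ), length(x), size(Ȳ, 2))
    for (k, i) in enumerate(inds)
        X̄[i, :] .+= Ȳ[k, :]
    end
    return X̄
end
```

For example, with `x = [10.0, 20.0, 30.0, 40.0]` and `inds = [2, 4, 2]`, `getindex_pullback_chunked(x, inds, Matrix{Float64}(I, 3, 3))` returns the transpose of `jacobian_by_rows(ȳ -> getindex_pullback(x, inds, ȳ), 3, 4)`. Passing a matrix to `getindex_pullback` fails with a `MethodError` because of the `AbstractVector` annotation; without that annotation, `ȳ[k]` would use linear indexing into the matrix and silently return a wrong result, which is the kind of hidden assumption described above.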