JuliaDiff / ForwardDiff.jl

Forward Mode Automatic Differentiation for Julia

Cannot compute gradient with `ArrayPartition` that holds containers of different types #706

Open bvdmitri opened 1 month ago

bvdmitri commented 1 month ago

ArrayPartition is a useful structure for concatenating arrays of different types. The type is defined in SciML/RecursiveArrayTools.jl.

ArrayPartitions are used in many places in the SciML ecosystem, and also elsewhere, e.g. in Manopt.jl. However, if an ArrayPartition wraps two containers, one with eltype Float64 and the other with eltype Int64, computing a gradient with ForwardDiff fails.

MWE:

julia> using ForwardDiff, RecursiveArrayTools

julia> v = [ 0.0, 1 ]
2-element Vector{Float64}:
 0.0
 1.0

julia> f(v) = sum(v)
f (generic function with 1 method)

julia> ForwardDiff.gradient(f, [ 0.0, 1 ])
2-element Vector{Float64}:
 1.0
 1.0

julia> ForwardDiff.gradient(f, ArrayPartition([ 0.0 ], [ 1 ]))
ERROR: MethodError: no method matching ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}(::Int64, ::ForwardDiff.Partials{2, Float64})

Closest candidates are:
  ForwardDiff.Dual{T, V, N}(::Number) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:78
  ForwardDiff.Dual{T, V, N}(::Any) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:77
  ForwardDiff.Dual{T, V, N}(::V, ::ForwardDiff.Partials{N, V}) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:17

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [3] getindex
    @ ./broadcast.jl:636 [inlined]
  [4] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:1003 [inlined]
  [7] copyto!
    @ ./broadcast.jl:956 [inlined]
  [8] materialize!
    @ ./broadcast.jl:914 [inlined]
  [9] materialize!
    @ ./broadcast.jl:911 [inlined]
 [10] seed!(duals::ArrayPartition{…}, x::ArrayPartition{…}, seeds::Tuple{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:52
 [11] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:23 [inlined]
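
For anyone without both packages at hand, the failing pattern can be reproduced as a minimal sketch in base Julia. It assumes a hypothetical ToyDual type whose only constructor, like the Dual{T, V, N}(::V, ::Partials{N, V}) candidate in the error, requires a value of exactly type V, and a hypothetical TwoParts vector that, like the ArrayPartition here, declares eltype Float64 but returns its stored Int. toy_seed is a simplified stand-in for the broadcast in seed!, not ForwardDiff's code.

```julia
# Hypothetical stand-in for ForwardDiff.Dual: the only constructor demands
# a value of exactly type V, mirroring Dual{T, V, N}(::V, ::Partials{N, V}).
struct ToyDual{V<:Real}
    value::V
    partial::V
    ToyDual{V}(value::V, partial::V) where {V<:Real} = new{V}(value, partial)
end

# Hypothetical stand-in for ArrayPartition([0.0], [1]): it declares
# eltype Float64 but returns elements of the second part unconverted.
struct TwoParts <: AbstractVector{Float64}
    a::Vector{Float64}
    b::Vector{Int}
end
Base.size(p::TwoParts) = (length(p.a) + length(p.b),)
Base.getindex(p::TwoParts, i::Int) =
    i <= length(p.a) ? p.a[i] : p.b[i - length(p.a)]

# Simplified seeding: the value type V comes from the declared eltype,
# not from each element's actual type.
toy_seed(x::AbstractVector) = ToyDual{eltype(x)}.(x, convert(eltype(x), 1.0))

toy_seed([0.0, 1])                # works: both elements really are Float64
# toy_seed(TwoParts([0.0], [1]))  # MethodError: ToyDual{Float64} has no method for an Int64 value
```

The first call succeeds because `[0.0, 1]` already holds two Float64 values; the commented-out call fails with the same kind of MethodError as the MWE, inside the same broadcast `copyto!` path.
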

mcabbott commented 1 month ago

I think ForwardDiff is confused because this ArrayPartition declares itself to have Float64 elements, but sometimes actually returns an Int:

julia> ap = ArrayPartition([ 0.0 ], [ 1 ])
([0.0], [1])

julia> size(ap), axes(ap), eltype(ap), supertype(typeof(ap))
((2,), (Base.OneTo(2),), Float64, AbstractVector{Float64})

julia> ap[1]  # no surprise
0.0

julia> ap[2]  # very surprising
1

julia> ap[1:2]  # here you get the expected eltype
2-element Vector{Float64}:
 0.0
 1.0

julia> ap[2,1:1]
1-element Vector{Float64}:
 1.0
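
One way to make this inconsistency visible is to compare each element's runtime type with the declared eltype. The sketch below uses a hypothetical helper, eltype_violations, and a hypothetical LooseVector that mimics the behaviour of `ap`; neither exists in Base or RecursiveArrayTools.

```julia
# Hypothetical helper: return the linear indices whose element is not
# an instance of the array's declared eltype.
eltype_violations(x::AbstractArray) =
    [i for i in LinearIndices(x) if !(x[i] isa eltype(x))]

# Hypothetical wrapper with the same inconsistency as `ap` above:
# it declares eltype Float64 but returns its stored elements unconverted.
struct LooseVector <: AbstractVector{Float64}
    data::Vector{Real}
end
Base.size(v::LooseVector) = size(v.data)
Base.getindex(v::LooseVector, i::Int) = v.data[i]

eltype_violations([0.0, 1])                   # empty: the literal promotes 1 to 1.0
eltype_violations(Real[0.0, 1])               # empty: 1 isa Real
eltype_violations(LooseVector(Real[0.0, 1]))  # [2]: 1 is an Int, not a Float64
```
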

The usual way to encode that the elements of a vector have different types is to give it an abstract eltype, which ForwardDiff seems able to handle.

(Note that the working example in the question constructs a Vector{Float64}, promoting 1 to 1.0 when the array is made.)

julia> x64 = [ 0.0, 1 ]  # this promotes on construction
2-element Vector{Float64}:
 0.0
 1.0

julia> ForwardDiff.gradient(f, x64)  # as in question
2-element Vector{Float64}:
 1.0
 1.0

julia> xabs = Real[ 0.0, 1 ]  # abstract eltype; could also use xabs = Union{Float64, Int}[ 0.0, 1 ]
2-element Vector{Real}:
 0.0
 1

julia> ForwardDiff.gradient(f, xabs)  # also OK, ForwardDiff not confused
2-element Vector{Float64}:
 1.0
 1.0
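
The difference between the three constructions can be checked in base Julia alone. This sketch uses only array literals, `eltype`, `typeof` and `isa`; the variable `xuni` is an added illustration of the Union alternative mentioned in the comment above.

```julia
# An untyped literal promotes its elements to a common type (Float64 here),
# while a typed literal keeps each element's own type.
x64  = [0.0, 1]
xabs = Real[0.0, 1]
xuni = Union{Float64, Int}[0.0, 1]

eltype(x64), typeof.(x64)    # Float64, [Float64, Float64]
eltype(xabs), typeof.(xabs)  # Real, [Float64, Int]
eltype(xuni), typeof.(xuni)  # Union{Float64, Int}, [Float64, Int]

# In all three, every element is an instance of the declared eltype,
# which is the property the ArrayPartition above lacks.
all(v -> all(xi -> xi isa eltype(v), v), (x64, xabs, xuni))  # true
```
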

Making ArrayPartition declare its eltype accurately would be the obvious fix here, and would probably avoid many other weird edge cases. (Or else its getindex could convert to the declared eltype.) Although I'm sure there's going to be some reason why that consistency is inconvenient for something.
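
Both container-side fixes can be sketched on a hypothetical Parts type standing in for `ArrayPartition([0.0], [1])`. This is an illustration under that assumption, not RecursiveArrayTools' implementation: the type parameter T is the declared eltype, getindex converts each stored value to T, and the two hypothetical constructors pick T either by promotion or as a Union of the parts' eltypes.

```julia
# Hypothetical two-part container; T is the element type it declares.
struct Parts{T} <: AbstractVector{T}
    a::Vector{Float64}
    b::Vector{Int}
end
Base.size(p::Parts) = (length(p.a) + length(p.b),)

# Look up element i in whichever part holds it, without conversion.
stored(p::Parts, i::Int) = i <= length(p.a) ? p.a[i] : p.b[i - length(p.a)]

# Every read converts to the declared eltype; convert returns the value
# unchanged when it is already an instance of T.
Base.getindex(p::Parts{T}, i::Int) where {T} = convert(T, stored(p, i))

# Option A: keep the promoted eltype (Float64) and convert on read.
promoted(a, b) = Parts{promote_type(eltype(a), eltype(b))}(a, b)
# Option B: declare the eltype accurately as a Union; reads are unchanged.
unioned(a, b) = Parts{Union{eltype(a), eltype(b)}}(a, b)

promoted([0.0], [1])[2]  # 1.0, a Float64
unioned([0.0], [1])[2]   # 1, an Int
```

With either choice every element is an instance of `eltype`, so code that trusts the declared eltype, as the seeding step does, sees consistent values.
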

It's possible that ForwardDiff could be made more robust to misleading signals. For instance making the ForwardDiff.Dual constructor called above promote its first argument might work here?
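
The constructor-side idea can be sketched on a toy type. It assumes a hypothetical ToyDual with the same strict constructor shape as in the error above, plus one extra method that converts a Real value to V before delegating. Whether ForwardDiff should add something like this is what is being discussed here, so this is not its actual code.

```julia
# Hypothetical stand-in for ForwardDiff.Dual: the inner constructor only
# accepts a value that is already of type V.
struct ToyDual{V<:Real}
    value::V
    partial::V
    ToyDual{V}(value::V, partial::V) where {V<:Real} = new{V}(value, partial)
end

# The suggested relaxation, on the toy type only: accept any Real value,
# convert it to V, then call the strict constructor above.
ToyDual{V}(value::Real, partial::V) where {V<:Real} =
    ToyDual{V}(convert(V, value), partial)

ToyDual{Float64}(0.0, 1.0)  # strict constructor used directly
ToyDual{Float64}(1, 1.0)    # Int value converted to 1.0 first
```

Conversion has its own failure mode: a value that cannot be represented in V, for example `ToyDual{Int}(1.5, 1)`, throws an InexactError instead of a MethodError.
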

bvdmitri commented 1 month ago

I opened an issue in RecursiveArrayTools as well, though I have a feeling that this behaviour might be by design.

For instance making the ForwardDiff.Dual constructor called above promote its first argument might work here?

To me that seems like the obvious fix, and it shouldn't break anything, right?

KristofferC commented 1 month ago

I have a feeling that this behaviour might be by design.

That just seems broken, though. If fixing RecursiveArrayTools also fixes this, then I don't think anything should be done here.