JuliaDiff / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.
https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface
MIT License

Mooncake Backend doesn't handle functions with StaticArrays output #642

Open jmurphy6895 opened 3 days ago

jmurphy6895 commented 3 days ago

If a function returns a static array or vector, the AutoMooncake backend errors:

using StaticArrays
using DifferentiationInterface
using ForwardDiff
using Mooncake

function MWE(x::AbstractVector)
    z = SVector{3}(x.^2)
    return z
end

test = rand(3)

# Errors with the Mooncake backend
f_ad, df_ad = value_and_jacobian(
    MWE,
    AutoMooncake(; config=nothing),
    test
)

# Same call with the ForwardDiff backend, for comparison
f_ad, df_ad = value_and_jacobian(
    MWE,
    AutoForwardDiff(),
    test
)

The Mooncake call gives the error:

ERROR: MethodError: no method matching copyto!(::Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}, ::SVector{3, Float64})
The function copyto! exists, but no method is defined for this combination of argument types.

Closest candidates are:
  copyto!(::IndexStyle, ::AbstractArray, ::IndexStyle, ::AbstractArray)
   @ Base abstractarray.jl:1064
  copyto!(::Zygote.Buffer, ::Any)
   @ Zygote C:\Users\jmurp.julia\packages\Zygote\nyzjS\src\tools\buffer.jl:54
  copyto!(::PermutedDimsArray, ::AbstractArray)
   @ Base permuteddimsarray.jl:295
  ...

Stacktrace:
 [1] copyto!!(dst::Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}, src::SVector{3, Float64})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\DifferentiationInterfaceMooncakeExt.jl:29
 [2] value_and_pullback(::Function, ::DifferentiationInterfaceMooncakeExt.MooncakeOneArgPullbackPrep{…}, ::AutoMooncake{…}, ::Vector{…}, ::Tuple{…})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\onearg.jl:35
 [3] prepare_pullback(::Function, ::AutoMooncake{Nothing}, ::Vector{Float64}, ::Tuple{SVector{3, Float64}})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\onearg.jl:22
 [4] _prepare_jacobian_aux(::DifferentiationInterface.PushforwardSlow, ::DifferentiationInterface.BatchSizeSettings{…}, ::SVector{…}, ::Tuple{…}, ::AutoMooncake{…}, ::Vector{…})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\first_order\jacobian.jl:167
 [5] prepare_jacobian(::typeof(MWE), ::AutoMooncake{Nothing}, ::Vector{Float64})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\first_order\jacobian.jl:108
 [6] value_and_jacobian(::typeof(MWE), ::AutoMooncake{Nothing}, ::Vector{Float64})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\fallbacks\no_prep.jl:60
 [7] top-level scope
   @ c:\Users\jmurp.julia\dev\SatelliteToolboxGravityModels\test\differentiability.jl:69
Some type information was truncated. Use show(err) to see complete types.

gdalle commented 3 days ago

Thanks for reporting this! Can you try it out with the branch from #643?

jmurphy6895 commented 3 days ago

Thanks for the quick reply! It looks like it's still giving the same error on my end.

gdalle commented 3 days ago

Oh right, I had misread the error. Modulo my hotfix, I now think this happens because DI expects an array but Mooncake wraps it into a Tangent. @willtebbutt any idea how we should handle this?

willtebbutt commented 3 days ago

So this looks to me like it's happening on https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/cc8818a2bb0fb3dab2abf29ba213f89213a8613a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L35, when we prepare the DI-friendly gradient to be passed into the reverse-pass of Mooncake.

If I'm not mistaken, we just need to add a method of copyto!! which knows how to translate a static array into a Tangent. There's enough information in the call to do this. It would just be something like

function copyto!!(dst::T, src::SVector{3, Float64}) where {T<:Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}}
    return T((data = getfield(src, :data), ))
end

in this case.

We'll need translation rules like this for most types so, thinking ahead to more general types, we'll never have a complete solution to this translation problem unless DI places some restrictions on the set of types that users are permitted to work with. My understanding is that you're not keen to restrict users in this way, so probably the best thing to do is to define a catch-all method of copyto!! which throws a more informative error message, saying something like "we didn't expect you to pass this type, so we don't have a conversion rule for it. Please open an issue."
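
A minimal sketch of such a catch-all, assuming nothing about the real extension code: `ToyTangent` is a hypothetical stand-in for Mooncake's Tangent and `toy_copyto!!` is a hypothetical stand-in for copyto!!, so that the snippet runs without either package installed.

```julia
# Hypothetical stand-in for Mooncake.Tangent (assumption, not the real type).
struct ToyTangent{NT<:NamedTuple}
    fields::NT
end

# A known rule: overwrite a mutable array in place and return it.
toy_copyto!!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, src)

# Catch-all: reached only when no more specific rule applies.
# It reports both types so that the missing rule is easy to identify.
function toy_copyto!!(dst, src)
    throw(ArgumentError(
        "no conversion rule from $(typeof(src)) to $(typeof(dst)); " *
        "this combination of types was not anticipated, please open an issue."
    ))
end
```

With these definitions, `toy_copyto!!(zeros(3), [1.0, 2.0, 3.0])` uses the array rule, while passing a `ToyTangent` as `dst` and a vector as `src` falls through to the catch-all and throws the `ArgumentError` instead of a bare `MethodError`.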

That being said, it might be that we can do something more general which says "the thing I'm trying to copyto!! to is a Tangent, therefore I just need to recursively pull out the fields of src and build NamedTuples out of them", and do a similar thing for MutableTangent but in-place. This might be functionality that I should provide as part of the tangent interface in Mooncake.
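
The recursive idea can be sketched roughly as follows. Everything here is illustrative: `ToyTangent`, `to_toy_tangent`, `Vec3` and `Wrapper` are hypothetical names, `ToyTangent` only mimics the shape of a Tangent wrapping a NamedTuple, and the in-place MutableTangent case is not covered.

```julia
# Hypothetical stand-in for Mooncake.Tangent (assumption, not the real type).
struct ToyTangent{NT<:NamedTuple}
    fields::NT
end

# Toy struct with the same layout as a static vector: one tuple field named `data`.
struct Vec3
    data::NTuple{3,Float64}
end

# Toy struct that nests another struct, to exercise the recursion.
struct Wrapper
    inner::Vec3
    scale::Float64
end

# Leaves: numbers are kept as they are.
to_toy_tangent(x::Number) = x

# Tuples are converted element by element.
to_toy_tangent(x::Tuple) = map(to_toy_tangent, x)

# Generic structs: pull out each field, convert it recursively,
# and rebuild a NamedTuple with the same field names.
function to_toy_tangent(x::T) where {T}
    names = fieldnames(T)
    vals = map(n -> to_toy_tangent(getfield(x, n)), names)
    return ToyTangent(NamedTuple{names}(vals))
end
```

For example, `to_toy_tangent(Vec3((1.0, 2.0, 3.0)))` yields a `ToyTangent` whose `fields.data` is the original tuple, and a `Wrapper` produces a nested `ToyTangent` for its `inner` field.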

Either way, since static arrays are something people are quite interested in currently, I would suggest just adding a conversion rule for this case, and punting the more general fix down the line -- I'll open an issue on Mooncake which references this issue.

Additionally, note that the thing that Mooncake will return in this instance is another Tangent (assuming that the argument to the function being differentiated is itself a static array), which is probably not what users want. It might be worth thinking a bit about whether we want to apply some specific translation rules to make the types "more user friendly" in a uniform way. Or I could think a bit about how to do something predictable on Mooncake's end.
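
One possible shape for such a translation rule, going in the opposite direction, is sketched below. It is only an illustration under stated assumptions: `ToyTangent`, `from_toy_tangent` and `Vec3` are hypothetical, and the sketch assumes the target type has a default constructor taking its fields in order, which is not true of every type.

```julia
# Hypothetical stand-in for Mooncake.Tangent (assumption, not the real type).
struct ToyTangent{NT<:NamedTuple}
    fields::NT
end

# Toy struct with the same layout as a static vector: one tuple field named `data`.
struct Vec3
    data::NTuple{3,Float64}
end

# Leaves: numbers are converted to the requested number type.
from_toy_tangent(::Type{T}, t::Number) where {T<:Number} = convert(T, t)

# Tuples: rebuild element by element, guided by the element types of T.
from_toy_tangent(::Type{T}, t::Tuple) where {T<:Tuple} =
    map(from_toy_tangent, fieldtypes(T), t)

# Structs: rebuild each field from the wrapped NamedTuple, then call the
# default constructor of T (assumed to accept the fields in order).
function from_toy_tangent(::Type{T}, t::ToyTangent) where {T}
    vals = map(from_toy_tangent, fieldtypes(T), Tuple(t.fields))
    return T(vals...)
end
```

Here `from_toy_tangent(Vec3, ToyTangent((data = (1.0, 2.0, 3.0),)))` returns a `Vec3` rather than the wrapper, which is the kind of "more user friendly" output discussed above.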