compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/
MIT License

More Friendly Types at the Interface Level #393

Open willtebbutt opened 2 days ago

willtebbutt commented 2 days ago

This issue https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/642 makes me wonder whether we need a systematic approach to translating between primal types and tangent types at the interface level.

For example, while users probably want to represent the tangent of an SArray with another SArray rather than a Tangent, Mooncake requires that users provide a Tangent.

I think we can probably define sensible translation functionality between primals and tangents which makes some choices around how to handle non-differentiable fields, but which works quite generically. This function would be something like

translate_to_tangent(t::IEEEFloat) = t
translate_to_tangent(t::VariousIntegerTypes) = NoTangent()
translate_to_tangent(t::Array{<:IEEEFloat}) = t
translate_to_tangent(t::Array) = map(translate_to_tangent, t)
function translate_to_tangent(t::P) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    return # recursively transform into tangent_type(P)
end

This would have the effect of, for example, dropping any non-differentiable fields.
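To make the recursive fallback concrete, here is a minimal self-contained sketch. It does not use Mooncake's actual types: `NoTangentSketch` and `to_tangent_sketch` are hypothetical names, a plain `NamedTuple` stands in for whatever `tangent_type(P)` would really be, and the `Union{Integer,AbstractString,Symbol}` rule is just one guess at what "non-differentiable" should cover. Non-differentiable fields are marked rather than removed here, which is one of several reasonable choices.

```julia
using Base: IEEEFloat  # Union{Float16, Float32, Float64}

# Hypothetical stand-in for Mooncake's NoTangent (assumption, not the real type).
struct NoTangentSketch end

# Floats and float arrays are their own "friendly" tangents.
to_tangent_sketch(x::IEEEFloat) = x
to_tangent_sketch(x::Array{<:IEEEFloat}) = x

# Assumed set of non-differentiable leaf types.
to_tangent_sketch(::Union{Integer,AbstractString,Symbol}) = NoTangentSketch()

# Other arrays are translated element by element.
to_tangent_sketch(x::Array) = map(to_tangent_sketch, x)

# Generic fallback: recurse over the fields of a struct.
function to_tangent_sketch(x::P) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    names = fieldnames(P)
    vals = map(n -> to_tangent_sketch(getfield(x, n)), names)
    return NamedTuple{names}(vals)  # NamedTuple stands in for tangent_type(P)
end

# Example type, purely illustrative.
struct Particle
    position::Vector{Float64}
    mass::Float64
    id::Int
    label::String
end

p = Particle([1.0, 2.0], 3.0, 7, "a")
t = to_tangent_sketch(p)
```

In this sketch `t.position` and `t.mass` carry the differentiable data, while `t.id` and `t.label` are `NoTangentSketch()` markers.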

On the way back, we could do a similar thing, but would need to pick a placeholder value for any non-differentiable fields. Not all types have a well-defined zero value (e.g. Strings and Symbols), so it might make sense for the conversion in the other direction to require that you pass in the primal value, from which we just copy the non-differentiable fields. For example

translate_to_primal(::P, t::P) where {P<:IEEEFloat} = t
translate_to_primal(p::P, t) where {P<:VariousIntegerTypes} = p
translate_to_primal(p::Array{P}, t::Array{P}) where {P<:IEEEFloat} = t
translate_to_primal(p::Array, t::Array) = map(translate_to_primal, p, t)
function translate_to_primal(p::P, t) where {P}
    # same idea as translate_to_tangent, but in the other direction
end
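The reverse direction can be sketched in the same spirit. Again, nothing here is Mooncake's API: `NoTangentSketch` and `to_primal_sketch` are hypothetical names, the tangent of a struct is assumed to be a `NamedTuple` with the same field names, and the struct is rebuilt by calling its default positional constructor `P(vals...)`, which will not work for types that define only custom inner constructors.

```julia
using Base: IEEEFloat  # Union{Float16, Float32, Float64}

# Hypothetical stand-in for Mooncake's NoTangent (assumption, not the real type).
struct NoTangentSketch end

# Differentiable leaves: the tangent already "looks like the primal".
to_primal_sketch(::P, t::P) where {P<:IEEEFloat} = t
to_primal_sketch(::Array{P}, t::Array{P}) where {P<:IEEEFloat} = t

# Non-differentiable leaves: copy the value from the primal.
to_primal_sketch(p::Union{Integer,AbstractString,Symbol}, ::NoTangentSketch) = p

# Other arrays are translated element by element.
to_primal_sketch(p::Array, t::Array) = map(to_primal_sketch, p, t)

# Generic fallback: recurse over fields, then rebuild a value of type P.
# Assumes P has a default positional constructor.
function to_primal_sketch(p::P, t::NamedTuple) where {P}
    vals = map(n -> to_primal_sketch(getfield(p, n), getfield(t, n)), fieldnames(P))
    return P(vals...)
end

# Example type, purely illustrative.
struct Particle
    position::Vector{Float64}
    mass::Float64
    id::Int
    label::String
end

p = Particle([1.0, 2.0], 3.0, 7, "a")
t = (position=[0.5, 0.25], mass=2.0, id=NoTangentSketch(), label=NoTangentSketch())
q = to_primal_sketch(p, t)
```

Here `q` is a `Particle` whose `position` and `mass` come from the tangent, and whose `id` and `label` are copied from `p`, so no placeholder value ever has to be invented for a `String` or `Symbol`.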

We could then e.g. shove this in value_and_pullback or whatever, so that users get "nice" types. We would want to ensure that there is a sensible way to opt-out of this translation, of course.

gdalle commented 1 day ago

I think this would be very helpful for users, because what most of them want is something that "looks like the primal", not the Mooncake Tangent, which could be considered an implementation detail.