For example, while users probably want to represent the tangent of an SArray with another SArray, rather than a Tangent, Mooncake requires that users provide a Tangent.
I think we can probably define sensible translation functionality between primals and tangents, which makes some choices about how to handle non-differentiable fields but works quite generically. This function would be something like:
translate_to_tangent(t::IEEEFloat) = t
translate_to_tangent(t::VariousIntegerTypes) = NoTangent()
translate_to_tangent(t::Array{<:IEEEFloat}) = t
translate_to_tangent(t::Array) = map(translate_to_tangent, t)
function translate_to_tangent(t::P) where {P}
    # `error` already throws, so no extra `throw` is needed
    isprimitivetype(P) && error("need a translation rule")
    return # placeholder: recursively transform into tangent_type(P)
end
This would have the effect of, for example, dropping any non-differentiable fields.
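To make the recursive fallback concrete, here is a rough sketch that runs without Mooncake loaded. It rests on several assumptions: NoTangent is a local stand-in rather than Mooncake's type, VariousIntegerTypes is a hypothetical alias, the Layer struct is made up for illustration, and a NamedTuple stands in for whatever tangent_type(P) would actually be. "Dropping" a non-differentiable field is modelled here by marking it with NoTangent().

```julia
using Base: IEEEFloat

# Stand-ins so the sketch is self-contained; these are NOT Mooncake's definitions.
struct NoTangent end
const VariousIntegerTypes = Union{Bool,Signed,Unsigned}  # hypothetical alias

translate_to_tangent(t::IEEEFloat) = t
translate_to_tangent(::VariousIntegerTypes) = NoTangent()
translate_to_tangent(::Union{String,Symbol}) = NoTangent()
translate_to_tangent(t::Array{<:IEEEFloat}) = t  # returned as-is (aliases the input)
translate_to_tangent(t::Array) = map(translate_to_tangent, t)

# Generic fallback: translate a struct field by field.
function translate_to_tangent(p::P) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    names = fieldnames(P)
    isempty(names) && return NoTangent()
    # Assumption: a NamedTuple plays the role of tangent_type(P).
    return NamedTuple{names}(map(n -> translate_to_tangent(getfield(p, n)), names))
end

# Hypothetical primal type with one differentiable field and two that are not.
struct Layer
    w::Vector{Float64}
    n::Int
    name::String
end

tangent = translate_to_tangent(Layer([1.0, 2.0], 3, "dense"))
```

Only the w field carries derivative information in the result; n and name are marked as non-differentiable.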
On the way back, we could do a similar thing, but would need to pick a placeholder value for any non-differentiable fields. Not all types have a well-defined zero value (e.g. Strings and Symbols), so it might just make sense to require that the conversion in the other direction takes the primal value as an argument, and we just copy its non-differentiable fields. For example:
translate_to_primal(::P, t::P) where {P<:IEEEFloat} = t
translate_to_primal(p::P, t) where {P<:VariousIntegerTypes} = p
# differentiable arrays: return the tangent values, as in the scalar case
translate_to_primal(p::Array{P}, t::Array{P}) where {P<:IEEEFloat} = t
translate_to_primal(p::Array, t::Array) = map(translate_to_primal, p, t)
function translate_to_primal(p::P, t) where {P}
    # same idea as translate_to_tangent, but in the other direction
end
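The reverse direction can be sketched the same way. This is only an illustration under stated assumptions: NoTangent and VariousIntegerTypes are local stand-ins, the Layer struct is hypothetical, the tangent is assumed to be a NamedTuple with the same field names as the primal, and the primal type is assumed to have a default positional constructor (types with custom inner constructors would need their own rule).

```julia
using Base: IEEEFloat

# Stand-ins so the sketch is self-contained; these are NOT Mooncake's definitions.
struct NoTangent end
const VariousIntegerTypes = Union{Bool,Signed,Unsigned}  # hypothetical alias

translate_to_primal(::P, t::P) where {P<:IEEEFloat} = t
translate_to_primal(p::VariousIntegerTypes, _) = p      # copy from the primal
translate_to_primal(p::Union{String,Symbol}, _) = p     # no zero value, so copy
translate_to_primal(::Array{P}, t::Array{P}) where {P<:IEEEFloat} = t
translate_to_primal(p::Array, t::Array) = map(translate_to_primal, p, t)

# Generic fallback: rebuild a value of type P field by field.
function translate_to_primal(p::P, t) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    vals = map(n -> translate_to_primal(getfield(p, n), getfield(t, n)), fieldnames(P))
    # Assumption: P can be built from its field values in declaration order.
    return P(vals...)
end

# Hypothetical primal type, with a tangent shaped like a NamedTuple.
struct Layer
    w::Vector{Float64}
    n::Int
    name::String
end

p = Layer([1.0, 2.0], 3, "dense")
t = (w = [0.5, -0.5], n = NoTangent(), name = NoTangent())
g = translate_to_primal(p, t)
```

The result g is a Layer whose w field holds the tangent values, while n and name are copied from p, which sidesteps the question of what a "zero" String would be.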
We could then e.g. shove this in value_and_pullback or whatever, so that users get "nice" types. We would want to ensure that there is a sensible way to opt-out of this translation, of course.
I think this would be very helpful for users, because what most of them want is something that "looks like the primal", not the Mooncake Tangent, which could be considered an implementation detail.
This issue https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/642 makes me wonder whether we need a systematic approach to translating between primal types and tangent types at the interface level.