jmurphy6895 opened this issue 3 days ago
Thanks for reporting this! Can you try it out with the branch from #643?
Thanks for the quick reply! It looks like it's still giving the same error on my end
Oh right, I had misread the error. Modulo my hotfix, I now think this happens because DI expects an array but Mooncake wraps it into a Tangent. @willtebbutt any idea how we should handle this?
So this looks to me like it's happening when we prepare the DI-friendly gradient to be passed into the reverse pass of Mooncake, at https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/cc8818a2bb0fb3dab2abf29ba213f89213a8613a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L35
If I'm not mistaken, we just need to add a method of copy_to!! which knows how to translate a static array into a Tangent. There's enough information in the call to do this. It would just be something like
function copyto!!(dst::T, src::SVector{3, Float64}) where {T<:Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}}
return T((data = getfield(src, :data), ))
end
in this case.
We'll need translation rules like this for most types, so, thinking ahead to more general types, we'll never have a complete solution to this translation problem unless DI restricts the set of types that users are permitted to work with. My understanding is that you're not keen to restrict users in this way, so probably the best thing to do is to define a catch-all method of copyto!! which throws (a more informative version of) an error saying something like "we didn't expect you to pass this type, so we don't have a conversion rule for it. Please open an issue."
That being said, we might be able to do something more general which says "the thing I'm trying to copy_to!! to is a Tangent, therefore I just need to recursively pull out the fields of src and build NamedTuples out of them", and do a similar thing for MutableTangent, but in place. This might be functionality that I should provide as part of the tangent interface in Mooncake.
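A rough sketch of the recursive idea in plain Julia, assuming leaves are numbers and that nested NamedTuples are an acceptable stand-in for the field layout a Tangent would wrap. `to_fields`, `Vec3`, and `Particle` are hypothetical names for illustration; the in-place MutableTangent variant is not shown.

```julia
# Leaves: numbers are returned unchanged.
to_fields(x::Number) = x

# Tuples: convert each element, keeping the tuple shape.
to_fields(x::Tuple) = map(to_fields, x)

# Structs: recursively convert every field and collect them into a NamedTuple
# whose names match the struct's field names.
function to_fields(x)
    names = fieldnames(typeof(x))
    vals = map(n -> to_fields(getfield(x, n)), names)
    return NamedTuple{names}(vals)
end

# Hypothetical example types: a static-vector-like struct nested inside another.
struct Vec3
    data::NTuple{3, Float64}
end

struct Particle
    pos::Vec3
    mass::Float64
end

println(to_fields(Particle(Vec3((1.0, 2.0, 3.0)), 2.0)))
```

For `Vec3` this produces the same `(data = (...),)` layout as the hand-written rule for `SVector{3, Float64}`, which is what makes a generic version plausible: the field structure alone determines the conversion.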
Either way, since people are currently quite interested in static arrays, I would suggest just adding a conversion rule for this case and punting the more general fix down the line. I'll open an issue on Mooncake which references this issue.
Additionally, note that the thing that Mooncake will return in this instance is another Tangent (assuming that the argument to the function being differentiated is itself a static array), which is probably not what users want. It might be worth thinking about whether we want to apply specific translation rules to make the returned types "more user friendly" in a uniform way. Alternatively, I could think about how to do something predictable on Mooncake's end.
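One uniform translation would be the inverse of field extraction: rebuild the primal type from the tangent's nested field NamedTuples. A minimal sketch in plain Julia, assuming the target type has Julia's default constructor that takes every field in declaration order; `from_fields`, `Vec3`, and `Particle` are hypothetical names, and types with custom constructors would still need dedicated rules.

```julia
# Leaves (numbers, tuples): convert directly to the declared field type.
from_fields(::Type{T}, x) where {T} = convert(T, x)

# Structs: look up each field by name in the NamedTuple, convert it recursively
# to the declared field type, then call the default positional constructor.
function from_fields(::Type{T}, nt::NamedTuple) where {T}
    vals = ntuple(i -> from_fields(fieldtype(T, i), nt[fieldname(T, i)]), fieldcount(T))
    return T(vals...)
end

# Hypothetical example types mirroring a static vector nested in a struct.
struct Vec3
    data::NTuple{3, Float64}
end

struct Particle
    pos::Vec3
    mass::Float64
end

p = from_fields(Particle, (pos = (data = (1.0, 2.0, 3.0),), mass = 2.0))
println(p)
```

Looking fields up by name rather than position means the sketch does not depend on the NamedTuple's field order matching the struct's.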
If a function returns a static array or vector, the AutoMooncake backend errors.
It gives the error: