Open MasonProtter opened 4 years ago
Yes, Shashi and I are looking into IRTools.jl and source-to-source transformation. But before that, we want to solve #33 first to make ForwardDiff2 functional.
I still have a bunch of concerns about this kind of approach, but since you're going to figure it out anyway, I may as well point out that this is really easy in IRTools. You always control exactly which methods you overload, when and how you recurse, etc. Here's a proto-AD that does what you want in 18 lines (15 of which are standard dual number stuff):
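```julia
using IRTools
using IRTools: @dynamo, IR
import Base: *, +

struct Dual{T}
  x::T
  ϵ::T
end

# Map a Dual type to its underlying value type; leave other types alone.
unwrap(T) = T
unwrap(::Type{Dual{T}}) where T = T

a::Dual * b::Dual = Dual(a.x * b.x, a.ϵ * b.x + b.ϵ * a.x)
a::Dual + b::Dual = Dual(a.x + b.x, a.ϵ + b.ϵ)

# Fetch the IR of the method `f` would dispatch to on the unwrapped argument types.
@dynamo function diff(f, args...)
  IR(f, unwrap.(args)...)
end
```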
```julia
julia> f(x::Int) = x*x;

julia> f(x::Float64) = x+x;

julia> diff(f, Dual(5, 1))
Dual{Int64}(25, 10)

julia> diff(f, Dual(5.0, 1.0))
Dual{Float64}(10.0, 2.0)
```
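For anyone wondering why the dynamo is needed at all, here is a minimal sketch of plain operator-overloading AD with no IRTools involved. The names `D`, `derivative`, `g` and `h` are made up for illustration and are not part of ForwardDiff2 or IRTools. It works on untyped functions, but a method restricted to `Float64` rejects the dual number outright, and that is the case the dispatch-on-unwrapped-types trick is meant to cover.

```julia
# Hypothetical stand-in for a dual number: a value plus a derivative part.
struct D{T}
    x::T   # primal value
    ϵ::T   # derivative (perturbation) component
end

# Product rule and sum rule on dual numbers.
Base.:*(a::D, b::D) = D(a.x * b.x, a.ϵ * b.x + b.ϵ * a.x)
Base.:+(a::D, b::D) = D(a.x + b.x, a.ϵ + b.ϵ)

# Seed the derivative part with one and read it back after calling f.
derivative(f, x) = f(D(x, one(x))).ϵ

# An unconstrained function differentiates fine: d/dx (x^2 + x) at 3 is 7.
g(x) = x * x + x
println(derivative(g, 3.0))  # 7.0

# A method restricted to Float64 cannot be called with a D{Float64}.
h(x::Float64) = x * x
try
    derivative(h, 3.0)
catch err
    println(err isa MethodError)  # true
end
```

With the dynamo version, `h` would be looked up as if it had been called with a `Float64`, and the dual arithmetic would then run inside that method body.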
This has been touched on in #24 and https://github.com/jrevels/Cassette.jl/pull/157, but it should probably have its own dedicated issue for clarity.
Last I heard, the plan was to support this using the Cassette PR referenced above, but that does not seem to have been well received by the repo maintainers.
An alternative way to support this would be to use this IRTools Dynamo example which works already: