YingboMa / ForwardDiff2.jl

Differentiating concretely typed function signatures #35

Open MasonProtter opened 4 years ago

MasonProtter commented 4 years ago

This has been touched on in #24 and https://github.com/jrevels/Cassette.jl/pull/157, but it should probably have its own dedicated issue for clarity.

julia> using ForwardDiff2: D
[ Info: Precompiling ForwardDiff2 [994df76e-a4c1-5e1f-bd5c-23b9b5303d4f]

julia> f(x::Float64) = x + 1
f (generic function with 1 method)

julia> D(f)(1)
ERROR: MethodError: no method matching f(::ForwardDiff2.Dual{ForwardDiff2.Tag{Nothing},Int64,Int64})
Closest candidates are:
  f(::Float64) at REPL[2]:1
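
For context, the failure above is ordinary method dispatch rather than anything specific to the AD machinery. Here is a minimal sketch using only Base; `ToyDual`, `f_concrete`, and `f_generic` are hypothetical names for illustration and are not ForwardDiff2's actual types:

```julia
# Hypothetical stand-in for an AD dual number (NOT ForwardDiff2's Dual type).
struct ToyDual{T}
    value::T     # primal value
    partial::T   # derivative carried alongside the value
end

# Adding a plain number shifts the value and leaves the derivative unchanged.
Base.:+(a::ToyDual, b::Number) = ToyDual(a.value + b, a.partial)

f_concrete(x::Float64) = x + 1   # concretely typed, like `f` above
f_generic(x) = x + 1             # same body, unconstrained signature

d = ToyDual(1.0, 1.0)

# Dispatch only sees the signature: the dual is not a Float64, so no method applies.
concrete_ok = hasmethod(f_concrete, Tuple{typeof(d)})   # false
generic_ok  = hasmethod(f_generic, Tuple{typeof(d)})    # true

result = f_generic(d)   # value 2.0, partial 1.0
```

The body of `f_concrete` would work fine on the dual; only the signature rules it out, which is why the proposals in this issue focus on reaching the method body while bypassing the signature check.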

Last I heard, the plan was to support this using the Cassette PR referenced above, but that PR does not seem to have been well received by the repo maintainers.

An alternative way to support this would be to build on this IRTools dynamo example, which already works:

using IRTools: @dynamo, argument!, IR

concrete(::Type{Type{T}}) where {T} = T

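function _sneaky_transform(f::Type, Types::Type{<:Tuple})
    # Fetch the IR of the method of `f` selected by the requested signature.
    ir = IR(f, Types.parameters...)
    # Insert an extra argument at position 2 so the IR's arguments line up
    # with sneakyinvoke's (f, T, args...).
    argument!(ir, at = 2)
    return ir
end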

@dynamo function sneakyinvoke(f, T, args...)
    Types = concrete(T)
    return _sneaky_transform(f, Types)
end

julia> sneakyinvoke(f, Tuple{Float64}, 1)
2

julia> D(x -> sneakyinvoke(f, Tuple{Float64}, x))(10.0)
1.0

YingboMa commented 4 years ago

Yes, Shashi and I are looking into IRTools.jl and source-to-source transformation. Before that, though, we want to solve #33 to make ForwardDiff2 functional.

MikeInnes commented 4 years ago

I still have a bunch of concerns about this kind of approach, but since you're going to figure it out anyway, I may as well point out that this is really easy in IRTools. You always control exactly what methods you overload, when and how you recurse etc. Here's a proto-AD that does what you want in 18 lines (15 of which are standard dual number stuff):

using IRTools
using IRTools: @dynamo, IR
import Base: *, +

struct Dual{T}
  x::T
  ϵ::T
end

unwrap(T) = T
unwrap(::Type{Dual{T}}) where T = T

a::Dual * b::Dual = Dual(a.x * b.x, a.ϵ * b.x + b.ϵ * a.x)
a::Dual + b::Dual = Dual(a.x + b.x, a.ϵ + b.ϵ)

@dynamo function diff(f, args...)
  IR(f, unwrap.(args)...)
end

julia> f(x::Int) = x*x;

julia> f(x::Float64) = x+x;

julia> diff(f, Dual(5, 1))
Dual{Int64}(25, 10)

julia> diff(f, Dual(5.0, 1.0))
Dual{Float64}(10.0, 2.0)
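
As a sanity check, the two results above can be reproduced without IRTools by running the same method bodies with unconstrained signatures. This is only a sketch: `SketchDual`, `square`, and `double` are hypothetical names introduced here, and the arithmetic rules are copied from the `Dual` definition above.

```julia
# Hypothetical dual number mirroring the `Dual` struct above.
struct SketchDual{T}
    x::T   # primal value
    ϵ::T   # derivative part
end

# Product rule and sum rule, as in the definitions above.
Base.:*(a::SketchDual, b::SketchDual) = SketchDual(a.x * b.x, a.ϵ * b.x + b.ϵ * a.x)
Base.:+(a::SketchDual, b::SketchDual) = SketchDual(a.x + b.x, a.ϵ + b.ϵ)

square(x) = x * x   # body of f(x::Int), signature dropped
double(x) = x + x   # body of f(x::Float64), signature dropped

r_int   = square(SketchDual(5, 1))       # x = 25,   ϵ = 10
r_float = double(SketchDual(5.0, 1.0))   # x = 10.0, ϵ = 2.0
```

The dynamo reaches the same answers while keeping `f`'s concrete signatures, because it selects the method from the unwrapped type and then runs that method's IR on the dual itself.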