JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

Trait-based dispatch #247

Open YingboMa opened 3 years ago

YingboMa commented 3 years ago

Using the standard Julia type hierarchy forces people to subtype `Number` (`<: Number`) when defining a type, but this is not always possible in generic code. I wonder if we could use isreal, iscomplex, isscalar, ismatrix, etc. to do the dispatch instead.
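To make the idea concrete, here is a minimal sketch of trait-based dispatch using the "Holy traits" pattern in plain Julia. A wrapper opts in to behaving like a scalar without subtyping `Number`. All names here (`Scalarness`, `ScalarLike`, `NotScalarLike`, `scalarness`, `Var`, `double`) are hypothetical and for illustration only. None of them are part of ChainRulesCore.jl.

```julia
# Hypothetical trait types: "behaves like a scalar" vs. "does not".
abstract type Scalarness end
struct ScalarLike <: Scalarness end
struct NotScalarLike <: Scalarness end

# Default trait assignment: Numbers are scalar-like, everything else is not.
scalarness(::Type{<:Number}) = ScalarLike()
scalarness(::Type) = NotScalarLike()

# Hypothetical wrapper that is NOT a subtype of Number...
struct Var{T}
    val::T
end
# ...but opts in to the scalar trait instead of subtyping.
scalarness(::Type{<:Var}) = ScalarLike()
# Minimal arithmetic so the scalar-like method below can run on a Var.
Base.:*(a::Number, v::Var) = Var(a * v.val)

# The entry point dispatches on the trait, not on the type hierarchy.
double(x) = double(scalarness(typeof(x)), x)
double(::ScalarLike, x) = 2 * x
double(::NotScalarLike, x) = error("double: argument is not scalar-like")
```

With this, `double(3)` returns `6` and `double(Var(3))` returns `Var(6)`, even though `Var` is not a `Number`.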

YingboMa commented 3 years ago

One common example is a wrapper type that tries to emulate `Var{T} <: T`. With the current approach, it's impossible to use ChainRules.jl on such a type.

oxinabox commented 3 years ago

I see what you are saying. Worth noting, to be clear: if the primal function uses trait dispatch, then the rule must also use it. If the primal function doesn't, then you need one rule method per primal method. In fact, you almost always need one rule method per primal method.
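A minimal sketch of "the rule must mirror the primal's trait dispatch", in plain Julia with no dependencies. The rule returns an `(output, pullback)` pair in the same spirit as ChainRulesCore's `rrule`, but it is written as a separate hypothetical function `square_rrule` so the example runs without any package. The trait types `IsScalar`/`IsArray` and the functions `shape` and `square` are also illustrative names.

```julia
# Hypothetical trait: classify the argument by its shape.
struct IsScalar end
struct IsArray end
shape(::Number) = IsScalar()
shape(::AbstractArray) = IsArray()

# Primal: one generic entry point that dispatches on the trait.
square(x) = square(shape(x), x)
square(::IsScalar, x) = x * x
square(::IsArray, x) = x .* x

# Rule: mirrors the primal's trait dispatch one-to-one.
# Returns (output, pullback), where pullback maps an output
# cotangent dy to the input cotangent.
square_rrule(x) = square_rrule(shape(x), x)
function square_rrule(::IsScalar, x)
    y = x * x
    pullback(dy) = 2x * dy          # d(x^2)/dx = 2x
    return y, pullback
end
function square_rrule(::IsArray, x)
    y = x .* x
    pullback(dy) = 2 .* x .* dy     # elementwise derivative
    return y, pullback
end
```

Each primal method `square(::Trait, x)` has exactly one matching rule method `square_rrule(::Trait, x)`, which is the one-to-one correspondence described above.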

One question, though, is why you need rules for a Var{T} that acts like T. In most code I have seen or written, such objects delegate a call of foo(x::Var{T}) to a call of foo(::T), either on a field or on a value it computes, potentially with some pre- or post-processing. Thus, if you have a rule for foo(::T), your AD should easily and efficiently (possibly even constant-folded) take care of making sure foo(::Var{T}) works.
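A minimal sketch of the delegation pattern described above, using a hypothetical wrapper `Wrapped{T}`. Calling `sin` on the wrapper unwraps the field, calls `sin(::T)`, and rewraps the result. An AD system that already has a rule for `sin(::Float64)` then only needs to differentiate through the field access and the constructor.

```julia
# Hypothetical wrapper that delegates to the value it holds.
struct Wrapped{T}
    val::T
end

# Delegation: unwrap, call the existing method for T, rewrap.
# No separate rule for sin(::Wrapped) should be needed; the AD can
# reuse the rule for sin(::T) and trace through the field access.
Base.sin(w::Wrapped) = Wrapped(sin(w.val))
```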

YingboMa commented 3 years ago

For symbolic manipulation, we need a Var{T} that doesn't call into foo(::T) directly, since Var{T} doesn't know the actual value. At the same time, Var{T} needs to behave like a T in order to trace through the program.

lamorton commented 3 years ago

Here's a potential use-case. ODESolutions allow indexing using a symbolic variable name. Internally, solution_interface.jl specializes getindex to either look up or compute the result as necessary.

Whether a variable is considered symbolic is defined by the function issymbollike, not by a particular type. It would be nice to dispatch on this, because at present I'm getting method ambiguities.
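A minimal sketch of routing `getindex` through a predicate instead of adding one method per index type. `Solution` and `is_symbol_like` are hypothetical stand-ins: `is_symbol_like` plays the role of `issymbollike` but is not its real definition. A single entry method converts the predicate into a `Val` and dispatches on that, which keeps the number of overlapping `getindex(::Solution, ::SomeType)` methods small.

```julia
# Hypothetical stand-in for an ODE solution with named series.
struct Solution
    names::Vector{Symbol}
    data::Vector{Vector{Float64}}
end

# Hypothetical predicate standing in for `issymbollike`:
# "is this index symbolic?" is decided by a function, not by a type hierarchy.
is_symbol_like(::Any) = false
is_symbol_like(::Symbol) = true

# One generic entry method; the branch is chosen by the predicate.
Base.getindex(sol::Solution, i) = _getindex(Val(is_symbol_like(i)), sol, i)

# Symbolic index: look the series up by name.
function _getindex(::Val{true}, sol::Solution, i)
    k = findfirst(==(i), sol.names)
    k === nothing && throw(KeyError(i))
    return sol.data[k]
end

# Ordinary index: positional access.
_getindex(::Val{false}, sol::Solution, i) = sol.data[i]
```

For example, with `sol = Solution([:x, :y], [[1.0, 2.0], [3.0, 4.0]])`, both `sol[:y]` and `sol[2]` return `[3.0, 4.0]`. A new symbolic type only has to extend `is_symbol_like`, not `getindex`.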