jw3126 opened 3 years ago
You are absolutely right: there's no way to represent a single primitive `f(args...)` as a tape, at least not as a tape different from the one for `args -> f(args...)`. I see several options here:
I lean towards the last option since it's unlikely that somebody will trace primitives outside the REPL, and warnings in the REPL are usually fine. I will let this idea mature for the next couple of days, though.
I think that the tape `args -> f(args...)` option sounds natural. One way to think about a tape is that it is a list of all primitive calls that occur. If the entry point is already primitive, then the tape is just this one primitive call.
I also expected that if you have `function f(x); g(x) end`, then `trace(f, x)` and `trace(g, x)` would be the same. This again would be consistent with tracing a primitive call returning the tape with just that call. What drawbacks do you see with this?
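To make the "a tape is the list of primitive calls" view concrete, here is a minimal sketch that swaps in plain operator overloading, which is not the IR-based tracing Ghost/Yota use. The names `Call`, `Tracked` and `toy_trace` are hypothetical, and only `*` with a number and `sin` count as primitives in this toy.

```julia
# Toy operator-overloading tracer (NOT how Ghost/Yota trace: they work on
# Julia's IR). It only illustrates "a tape is the list of primitive calls".

# One recorded primitive call: the function and its (untracked) arguments.
struct Call
    fn::Function
    args::Tuple
end

# A value that carries the tape it should record into.
struct Tracked
    val::Float64
    tape::Vector{Call}
end

# Primitives of this toy: multiplication by a number, and `sin`.
function Base.:*(a::Number, x::Tracked)
    push!(x.tape, Call(*, (a, x.val)))
    return Tracked(a * x.val, x.tape)
end

function Base.sin(x::Tracked)
    push!(x.tape, Call(sin, (x.val,)))
    return Tracked(sin(x.val), x.tape)
end

# Run `fn` on a tracked input and return the recorded primitive calls.
function toy_trace(fn, x::Float64)
    tape = Call[]
    fn(Tracked(x, tape))
    return tape
end

g(x) = 2x
f(x) = g(x)

tf = toy_trace(f, 1.0)    # f is not primitive: we see the call inside g
tg = toy_trace(g, 1.0)    # same single call as tf
ts = toy_trace(sin, 1.0)  # entry point is itself primitive: one-call tape
```

Under these assumptions, `tf` and `tg` both hold a single `*(2, 1.0)` call, and tracing `sin` directly yields a one-call tape. An overloading tracer like this never has to decide what the tape's first input should be, which is the part that is awkward for an IR-based tape.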
Also, one could add to the list: 1b. Throw an error by default, but allow disabling that error with a keyword, which then traces into the primitive as is done currently. In general, I favor an error that must be explicitly disabled over a warning.
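Option 1b could look roughly like the following sketch. Everything in it (`is_primitive_sketch`, `trace_sketch`, `PrimitiveEntryError`, the `allow_primitive` keyword and the hard-coded primitive set) is hypothetical and not part of Yota's or Ghost's API; the body only evaluates `f` where a real tracer would record calls.

```julia
# Hypothetical sketch of option 1b; none of these names exist in Yota or Ghost.

# Toy stand-in for "is this call a primitive?".
const SKETCH_PRIMITIVES = Set{Any}([sin, cos, *, +])
is_primitive_sketch(f) = f in SKETCH_PRIMITIVES

struct PrimitiveEntryError <: Exception
    fn::Any
end

Base.showerror(io::IO, e::PrimitiveEntryError) =
    print(io, "refusing to trace into primitive ", e.fn,
          "; pass allow_primitive=true to trace into it anyway")

# Error by default when the entry point is primitive; the keyword opts out.
function trace_sketch(f, args...; allow_primitive::Bool=false)
    if is_primitive_sketch(f) && !allow_primitive
        throw(PrimitiveEntryError(f))
    end
    # A real tracer would record the calls made inside `f` here; this sketch
    # only evaluates `f` to show which branch was taken.
    return f(args...)
end
```

Compared to a warning, the default error cannot be overlooked, while the current behaviour stays one keyword away.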
> I also expected that if you have `function f(x); g(x) end`, then `trace(f, x)` and `trace(g, x)` would be the same.
`trace()` already works like this:
```julia
julia> g(x) = 2x
g (generic function with 1 method)

julia> f(x) = g(x)
f (generic function with 1 method)

julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(f)
inp %2::Float64
%3 = *(2, %2)::Float64

julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(g)
inp %2::Float64
%3 = *(2, %2)::Float64
```
`grad()` behaves similarly, except for caching.
> One way to think about a tape is that it is a list of all primitive calls that occur.
The first input to a tape is usually the object being called. In the case of `args -> f(args...)`, this object is an anonymous function, which is fine. In the case of a primitive, it's unclear what we should put there instead.
The most straightforward way is to wrap the primitive in an anonymous function, but this would break the assumption that `tape[V(1)].fn == f`, which may be useful for introspection and downstream transformations. It would also break on closures/callable structs.
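The objection to wrapping can be seen without any tracer at all. In this plain-Julia sketch (the names `prim` and `wrapped` are illustrative only), the anonymous wrapper computes the same value as the primitive but is a different object of a different type, so an identity check on the first tape input, like the `tape[V(1)].fn == f` assumption, would no longer hold.

```julia
# Plain Julia, no tracer involved: wrapping a primitive changes its identity.
prim = sin
wrapped = (args...) -> prim(args...)

same_value  = wrapped(0.5) == prim(0.5)         # the wrapper computes the same thing
same_object = wrapped === prim                  # but it is a different object
same_type   = typeof(wrapped) === typeof(prim)  # of a different (anonymous function) type
```

For the same reason, any method or rule written for `typeof(sin)` would not dispatch on the wrapper.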
The same applies to skipping the first argument altogether.
Putting the primitive itself as the first input also sounds weird: it would look like a recursive function, which it isn't.
On the other hand, tracing a primitive function doesn't seem to be a big use case, so raising an error or warning sounds like a reasonable solution to me, at least until we hit a real case where it's not enough.
Perhaps another entry for the list: it seems that `grad` currently does not work with `LinearAlgebra.Adjoint`:
```julia
julia> A = rand(100, 100)

julia> x = rand(100)

julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
```
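For reference, a rule for `*(::Float64, ::Adjoint, ::Vector)` only needs the derivatives of `y = a * (u' * v)`. The sketch below writes that pullback by hand in plain Julia; it does not use ChainRulesCore, and `mul3_pullback_sketch` is a hypothetical name. It also assumes that the row vector in the failing call comes from `x' * A`, and under that assumption it recovers the familiar gradient `0.5 * (A + A') * x` of the quadratic form.

```julia
using LinearAlgebra

# Hand-written reverse rule for y = a * (u' * v) with scalar `a`, row vector
# `u'` (an Adjoint) and column vector `v`. It mirrors what an rrule for
# *(::Float64, ::Adjoint, ::Vector) would have to return, but it is plain
# Julia, not ChainRulesCore.
function mul3_pullback_sketch(a::Real, ut::Adjoint{<:Real,<:AbstractVector},
                              v::AbstractVector)
    u = parent(ut)
    s = dot(u, v)                      # u' * v
    y = a * s
    pullback(ȳ) = (ȳ * s,              # ∂y/∂a
                   ȳ * a * v',         # ∂y/∂u', kept as a row vector
                   ȳ * a * u)          # ∂y/∂v
    return y, pullback
end

# Quadratic form from the issue, assuming the 3-argument call is
# *(0.5, x' * A, x): the row vector depends on x through x' * A.
A = rand(5, 5)
x = rand(5)
y, back = mul3_pullback_sketch(0.5, x' * A, x)
_, du_t, dv = back(1.0)

# Chain rule: the row-vector cotangent flows back through x' * A as A * du_t',
# and the direct term dv is added on top.
g_manual = A * du_t' + dv
```

Comparing `g_manual` with `0.5 * (A + A') * x` is a quick sanity check for any such rule added upstream.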
Thanks! I posted the corresponding issue in JuliaDiff/ChainRules.jl#589 as other AD systems may benefit from such a rule too.
Here is a list of some operations that did not work for me. What about the errors that mention ChainRules in their message? For instance, in the `sum` example, I guess we are tracing too deep into the `sum` implementation; there already exists a higher-level `sum` rule: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
@dfdx
Maybe the `trace` used for `gradtape` should have an `is_primitive` that checks if the signature is covered by an `rrule`? Actually, I realize that `Yota.is_primitive !== Ghost.is_primitive` and there is already such a rule. I think the issue is what should happen when one starts tracing with a call that is already primitive. It's not obvious what the best design is. Currently, such a call is entered anyway, which is why, e.g., `sum([1.0])` fails.
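The "signature is covered by an `rrule`" check can be illustrated with Base's `hasmethod`. In this sketch, `toy_rrule` stands in for `ChainRulesCore.rrule`, and `has_rule`/`entry_action` are hypothetical helpers; it is not how `Yota.is_primitive` is actually implemented. With such a check at the entry point, a tracer could record `sum([1.0])` as one primitive call instead of entering the `sum` implementation.

```julia
# `toy_rrule` is a stand-in for ChainRulesCore.rrule: a high-level rule for
# `sum` over arrays that returns the primal value and a pullback.
toy_rrule(::typeof(sum), x::AbstractArray) = (sum(x), ȳ -> (nothing, fill(ȳ, size(x))))

# A call counts as "primitive" if some toy_rrule method covers its signature.
has_rule(f, args...) = hasmethod(toy_rrule, Tuple{typeof(f), map(typeof, args)...})

# What a tracer could do with the entry call under this check.
entry_action(f, args...) = has_rule(f, args...) ? :record_as_primitive : :trace_into

entry_action(sum, [1.0])        # → :record_as_primitive (covered by the array rule)
entry_action(sum, (1.0, 2.0))   # → :trace_into (no rule for tuples in this toy)
```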