dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia

list of operations that grad does not work with #93

Open jw3126 opened 3 years ago

jw3126 commented 3 years ago

Here is a list of some operations that did not work for me. I wonder about the errors that mention ChainRules in their message. For instance, in the sum example, I guess we are tracing too deep into the sum implementation, even though a more high-level sum rule exists: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
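
For reference, that high-level rule can be exercised directly (a minimal sketch; the exact pullback return values vary across ChainRules versions):

using ChainRulesCore, ChainRules

# sum has a high-level rrule, so no tracing into mapreduce is needed:
y, pullback = ChainRulesCore.rrule(sum, [1.0, 2.0, 3.0])
# y == 6.0; pullback(1.0) returns (NoTangent(), dx) with dx conceptually
# a vector of ones (possibly lazy, depending on the version)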

@dfdx Maybe the trace used for gradtape should have an is_primitive that checks whether the signature is covered by an rrule? I realize that Yota.is_primitive !== Ghost.is_primitive and that such a check already exists. I think the issue is what should happen when one starts tracing with a call that is already primitive. It's not obvious what the best design is. Currently, such a call is entered anyway, which is why e.g. sum([1.0]) fails.
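
A minimal sketch of such a check (the name is hypothetical; a real implementation would inspect the rrule method table instead of running the primal):

using ChainRulesCore

# ChainRulesCore's generic fallback rrule(::Any, ...) returns nothing, so
# calling rrule is a valid, if expensive, probe for rule coverage -- it
# executes the primal function as a side effect.
is_rrule_primitive(f, args...) = ChainRulesCore.rrule(f, args...) !== nothing

is_rrule_primitive(sum, [1.0])   # true: the high-level sum rule applies
is_rrule_primitive(x -> x, 1.0)  # false: no rule for an anonymous closure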

################################################################################
Yota.gradtape(sum, [1.0])
fails
No deriative rule found for op %42 = mapreduce(identity, add_sum, %2)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(sum, abs2, [1.0])
fails
No deriative rule found for op %30 = mapreduce(%2, add_sum, %3)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(abs2), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(identity, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(sin, 1.0)
fails
MethodError: Cannot `convert` an object of type Float64 to an object of type Ghost.Variable
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
  Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:22
################################################################################
Yota.gradtape(*, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(*, 1.0, 2.0)
fails
No deriative rule found for op %4 = mul_float(%2, %3)::Float64, try defining it using ChainRules.rrule(::Core.IntrinsicFunction, ::Float64, ::Float64) = ...
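
A workaround consistent with the analysis above is to wrap the primitive in a trivial closure, so the entry point itself is not a primitive call (a sketch, not verified here):

using Yota

# The closure is not primitive, so tracing enters it and then stops at sum,
# where the high-level ChainRules rrule applies.
Yota.grad(x -> sum(x), [1.0])
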
dfdx commented 3 years ago

You are absolutely right - there's no way to represent a single primitive f(args...) as a tape, at least as a tape different from the one for args -> f(args...). I see several options here:

  1. Leave it as is, letting people trace the code of primitives even if sometimes it will confuse them.
  2. Forbid tracing the primitives. But what if it is just what somebody wanted to do?
  3. Show a warning explaining that it's probably not what a user wants, but letting them do it anyway.

I lean towards the last option, since it's unlikely somebody will trace primitives outside the REPL, and warnings in the REPL are usually fine. I will let this idea mature for the next couple of days, though.
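
Roughly, option 3 could look like this (a sketch with hypothetical names, not actual Yota code; is_rrule_primitive is the probe sketched earlier in this thread):

using ChainRulesCore

is_rrule_primitive(f, args...) = ChainRulesCore.rrule(f, args...) !== nothing

# Hypothetical entry-point check run before tracing begins:
function warn_if_primitive(f, args...)
    if is_rrule_primitive(f, args...)
        @warn "$f is already a primitive covered by an rrule; tracing into its implementation is probably not what you want."
    end
end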

jw3126 commented 3 years ago

I think the tape args -> f(args...) option sounds natural. One way to think about a tape is that it is a list of all the primitive calls that occur. If the entry point was already primitive, then the tape is just this one primitive call. I also expected that if you have function f(x); g(x) end, then trace(f, x) and trace(g, x) would be the same. This again would be consistent with tracing a primitive call returning the tape with just that call. What drawbacks do you see with this?

Also, one could add to the list: 1b. Throw an error by default, but let that error be disabled with a keyword that allows tracing into a primitive, as happens currently; see the sketch below. Generally I favor an error that must be explicitly disabled over a warning.
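
A sketch of option 1b with hypothetical names (not Yota's actual API):

using Ghost: trace
using ChainRulesCore

is_rrule_primitive(f, args...) = ChainRulesCore.rrule(f, args...) !== nothing

# Error by default; an explicit keyword re-enables the current behavior.
function checked_trace(f, args...; allow_primitive::Bool=false)
    if !allow_primitive && is_rrule_primitive(f, args...)
        error("$f is already covered by an rrule; pass allow_primitive=true to trace into its implementation anyway.")
    end
    return trace(f, args...)
end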

dfdx commented 3 years ago

I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same.

trace() already works like this:

julia> g(x) = 2x
g (generic function with 1 method)

julia> f(x) = g(x)
f (generic function with 1 method)

julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Float64
  %3 = *(2, %2)::Float64

julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(g)
  inp %2::Float64
  %3 = *(2, %2)::Float64

grad() behaves similarly with the exception for caching.

One way to think about a tape is that it is a list of all primitive calls that occur.

The first input to a tape is usually the object being called. In the case of args -> f(args...) this object is an anonymous function, which is fine. In the case of a primitive, it's unclear what we should put there instead.

The most straightforward way is to wrap the primitive into an anonymous function, but that breaks the assumption that tape[V(1)].fn == f, which may be useful for introspection and downstream transformations. It will also break on closures/callable structs.

The same applies to skipping the first argument altogether.

Putting the primitive itself as the first input also sounds odd: it will look like a recursive function, which it is not.
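
For concreteness, such a single-call tape might print like this (a hypothetical display, following the tape notation above):

Tape{Yota.GradCtx}
  inp %1::typeof(sin)
  inp %2::Float64
  %3 = sin(%2)::Float64

Here %3 reads as %1 being applied to %2, i.e. sin calling itself, which is misleading.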

On the other hand, tracing a primitive function doesn't seem to be a big use case, so raising an error or a warning sounds like a reasonable solution to me, at least until we hit a real case where it's not enough.

lassepe commented 2 years ago

Perhaps another entry for the list: it seems that grad currently does not work with LinearAlgebra.Adjoint:

julia> A = rand(100, 100)
julia> x = rand(100)
julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
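
One possible workaround until such a rule exists is to express the quadratic form via dot, which does have a ChainRules rrule (a sketch, not verified against this Yota version):

using LinearAlgebra, Yota

A = rand(100, 100)
x = rand(100)

# dot(x, A * x) avoids the 3-argument * involving an Adjoint.
Yota.grad(x -> 0.5 * dot(x, A * x), x)
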
dfdx commented 2 years ago

Thanks! I posted the corresponding issue in JuliaDiff/ChainRules.jl#589 as other AD systems may benefit from such a rule too.