compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/
MIT License
Some instances of invoke #206

Open willtebbutt opened 1 month ago

willtebbutt commented 1 month ago

invoke is an interesting built-in function. I came across a use of it when debugging a PR linked to #197.

invoke recap

Recall that invoke behaves as follows:

julia> foo(x) = 5x
foo (generic function with 1 method)

julia> foo(x::Float64) = 6x
foo (generic function with 2 methods)

julia> foo(5.0)
30.0

julia> invoke(foo, Tuple{Any}, 5.0)
25.0

invoke lets us call whichever method of a function applies to the types provided in the second argument.

There is, in principle, no particular performance penalty associated with this; it just changes which Method of foo is specialised to Float64. In this case, a MethodInstance is produced for the Method returned by

which(Tuple{typeof(foo), Any})

which is specialised for Float64 argument types.
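As a small check of the recap above, the following sketch uses only `which`, `invoke`, and the two `foo` methods from the example. It confirms that the type tuple selects a different Method while the argument value stays the same. The variable names `m_any` and `m_f64` are just illustrative.

```julia
foo(x) = 5x
foo(x::Float64) = 6x

# `which` returns the Method that dispatch would select for the given types.
m_any = which(foo, Tuple{Any})       # the foo(x) method
m_f64 = which(foo, Tuple{Float64})   # the foo(x::Float64) method

# The two Methods are distinct, and their signatures record the declared types.
@assert m_any !== m_f64
@assert m_any.sig == Tuple{typeof(foo), Any}
@assert m_f64.sig == Tuple{typeof(foo), Float64}

# Ordinary dispatch picks the Float64 method; invoke forces the Any method.
@assert foo(5.0) == 30.0
@assert invoke(foo, Tuple{Any}, 5.0) == 25.0
```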

The IR generated for a call to invoke varies depending on a few different factors. I believe that all possible cases are considered below:

1. Statically-Resolvable Types + Inlinable

If an invoke call appears inside another function the compiler is at liberty to entirely inline it away. For example, the optimised IR associated to bar is

julia> bar(x::Float64) = invoke(foo, Tuple{Any}, x);

julia> @code_warntype optimize=true bar(5.0)
MethodInstance for bar(::Float64)
  from bar(x::Float64) @ Main REPL[16]:1
Arguments
  #self#::Core.Const(bar)
  x::Float64
Body::Float64
1 ─ %1 = Base.mul_float(5.0, x)::Float64
└──      return %1

2. Statically-Resolvable Types + Not Inlinable

In instances where the compiler is able to resolve the MethodInstance statically, but not inline the call away, it will produce an :invoke Expr (not to be confused with the function invoke). For example

julia> baz(x::Float64) = @noinline invoke(foo, Tuple{Any}, x)
baz (generic function with 1 method)

julia> @code_warntype optimize=true baz(5.0)
MethodInstance for baz(::Float64)
  from baz(x::Float64) @ Main REPL[23]:1
Arguments
  #self#::Core.Const(baz)
  x::Float64
Body::Float64
1 ─ %1 = invoke Main.foo(x::Float64)::Float64
└──      return %1

Interestingly, one cannot tell from looking at the displayed IR that the call to invoke was there -- i.e. the print-out of the IR is the same as that which would have been generated if the definition of baz were

baz(x::Float64) = @noinline foo(x)

The difference, however, is most certainly there:

julia> ci = Base.code_typed(baz, Tuple{Float64})[1][1]
CodeInfo(
1 ─ %1 = invoke Main.foo(x::Float64)::Float64
└──      return %1
)

julia> mi = ci.code[1].args[1]
MethodInstance for foo(::Float64)

julia> mi.def
foo(x)
     @ Main REPL[2]:1

By looking at the def field of the MethodInstance associated to the :invoke Expr, we see that it has specialised the foo(x::Any) method of foo.

3. Dynamically-Resolved Types

If the types are not known to the invoke call statically, then a :call to invoke (the built-in) appears in the IR. For example:

julia> function bleh(x::Float64)
           if @noinline randn() > 0
               T = Tuple{Float64}
           else
               T = Tuple{Any}
           end
           return invoke(foo, T, x)
       end;

julia> @code_warntype optimize=true bleh(5.0)
MethodInstance for bleh(::Float64)
  from bleh(x::Float64) @ Main REPL[58]:1
Arguments
  #self#::Core.Const(bleh)
  x::Float64
Body::Any
1 ─ %1 = invoke Main.randn()::Float64
│   %2 = invoke Main.:>(%1::Float64, 0::Int64)::Bool
└──      goto #3 if not %2
2 ─      goto #4
3 ─      nothing
4 ┄ %6 = φ (#2 => Tuple{Float64}, #3 => Tuple{Any})::Union{Type{Tuple{Any}}, Type{Tuple{Float64}}}
│   %7 = Main.invoke(Main.foo, %6, x)::Any
└──      return %7

In this case it is plain to see what is going on from the Main.invoke :call expression near the end of the IR.

Differentiating invoke

invoke complicates the AD story only slightly. We address each of the above cases in turn:

1. Statically-Resolvable Types + Inlinable

In this case, we need to avoid inlining if there is a method of rrule!! which is applicable to the set of types provided to invoke. If this is not the case, then we can safely inline.

2. Statically-Resolvable Types + Not Inlinable

Derive a rule as usual, but the code which we look up must be selected by the types provided to invoke, rather than by the types of the argument values. If there's an applicable method of rrule!!, use that, otherwise derive the rule.

3. Dynamically-Resolved Types

Basically the same as 2.

Implementation Strategy

I think the way to do this is to implement a method of is_primitive which calls is_primitive on the function and the types provided to invoke. Something like

function is_primitive(ctx, ::Type{Tuple{typeof(invoke), typeof(f), T, Targs...}})
    return is_primitive(ctx, Tuple{typeof(f), T...})
end

(this code won't run as written, but it should illustrate what I mean).
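For what it's worth, here is a self-contained sketch of the same idea that does run. It is not Mooncake's API: `declared_primitive` and `is_primitive_sketch` are hypothetical stand-ins, the context argument is dropped for brevity, and the `foo` methods are the ones from the recap. It assumes the signature of an invoke call has the form `Tuple{typeof(invoke), typeof(f), Type{TT}, argument types...}`.

```julia
foo(x) = 5x
foo(x::Float64) = 6x

# Hypothetical registry of primitives, keyed on signature (assumption, not Mooncake's).
declared_primitive(::Type{<:Tuple}) = false
declared_primitive(::Type{<:Tuple{typeof(foo), Float64}}) = true

function is_primitive_sketch(sig::Type{<:Tuple})
    ps = fieldtypes(sig)
    # Signature of `invoke(f, TT, args...)`: (typeof(invoke), typeof(f), Type{TT}, ...).
    if length(ps) >= 3 && ps[1] === typeof(invoke) &&
       ps[3] isa DataType && ps[3] <: Type{<:Tuple}
        TT = ps[3].parameters[1]  # the types that select the method
        # Query with the types given to invoke, not the argument types.
        return declared_primitive(Tuple{ps[2], fieldtypes(TT)...})
    end
    return declared_primitive(sig)
end

sig_f64 = Tuple{typeof(invoke), typeof(foo), Type{Tuple{Float64}}, Float64}
sig_any = Tuple{typeof(invoke), typeof(foo), Type{Tuple{Any}}, Float64}

@assert is_primitive_sketch(sig_f64)                       # hits the declared primitive
@assert !is_primitive_sketch(sig_any)                      # foo(::Any) is not declared
@assert is_primitive_sketch(Tuple{typeof(foo), Float64})   # ordinary call
```

The `ps[3] isa DataType` check means the sketch only handles the statically-resolved case. If the type argument is a Union, as in case 3, it falls through to the plain lookup.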

We then need to modify the implementation of build_rrule to keep track of both the types which identify the method that we specialise, and the types of the arguments that we pass to the specialised method. At present, they are always the same thing.
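To make the distinction concrete, here is a rough sketch of carrying the two signatures around separately. Everything in it is hypothetical (`InvokeSigSketch`, `sig_for_call`, `sig_for_invoke`, `method_of`) and none of it is taken from build_rrule. It only shows that the two signatures coincide for an ordinary call and differ for invoke.

```julia
foo(x) = 5x
foo(x::Float64) = 6x

# Hypothetical container: one signature selects the Method, the other
# records the argument types the selected Method is specialised on.
struct InvokeSigSketch
    method_sig::Type{<:Tuple}
    arg_sig::Type{<:Tuple}
end

# Ordinary call: both signatures are the same.
function sig_for_call(f, args...)
    sig = Tuple{typeof(f), map(typeof, args)...}
    return InvokeSigSketch(sig, sig)
end

# invoke(f, TT, args...): TT selects the Method, typeof.(args) the specialisation.
function sig_for_invoke(f, ::Type{TT}, args...) where {TT<:Tuple}
    method_sig = Tuple{typeof(f), fieldtypes(TT)...}
    arg_sig = Tuple{typeof(f), map(typeof, args)...}
    return InvokeSigSketch(method_sig, arg_sig)
end

# Look up the Method using the method-selecting signature only.
method_of(s::InvokeSigSketch) = which(s.method_sig)

s_call = sig_for_call(foo, 5.0)
s_inv = sig_for_invoke(foo, Tuple{Any}, 5.0)

@assert s_call.method_sig == s_call.arg_sig
@assert s_inv.method_sig == Tuple{typeof(foo), Any}
@assert s_inv.arg_sig == Tuple{typeof(foo), Float64}
@assert method_of(s_inv) === which(foo, Tuple{Any})
@assert method_of(s_call) === which(foo, Tuple{Float64})
```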

Summary

Supporting invoke is a bit of work, but should ultimately be fairly straightforward if the above is all correct.

willtebbutt commented 1 month ago

#212 largely solves this problem: it basically resolves 1 and 2 above, but will not prevent inlining away things which have rules. This means that if you are unlucky enough to call invoke and hit a method to which you might want to apply an rrule!!, the rule will be missed. This should not cause correctness issues, but may cause some calls to invoke not to work. On the basis that invoke usage is fairly uncommon, I'm not too worried about this.

Completely resolving this problem looks likely to be a bit more work, and to be a bit of a pain. This is because the compiler "lowers" invoke calls to :invoke expressions prior to hitting inlining. For example,

Expr(:call, invoke, f, TT, args...)

gets converted to

Expr(:invoke, a_method_instance, f, args...)

This is a problem because we lose access to TT, which is needed to assess whether or not something is a primitive.

In order to resolve this, I believe it will be necessary to customise the compilation pipeline earlier on. This is moderately straightforward, but I'm inclined to leave it until we find it to be a problem in practice, given that there are other more pressing things to deal with.

sunxd3 commented 1 month ago

Just out of curiosity, would invokelatest create extra trouble beyond invoke?

willtebbutt commented 1 month ago

Almost certainly -- I've not looked into exactly what happens yet, but I suspect our best bet will be to provide some kind of informative error if someone attempts to AD through invokelatest.

willtebbutt commented 1 month ago

Didn't mean to close this.