dfdx / Umlaut.jl


Stack traces #33

Closed willtebbutt closed 11 months ago

willtebbutt commented 1 year ago

Currently it's quite hard to know what bit of code causes a failure in Umlaut's tracing functionality.

For example, I have the following:

ERROR: AssertionError: Expected Pair{IRCode,...}, but got Pair{Method, DataType} instead for f=arrayref with argtypes=(Expr, Vector{Any}, Int64)
Stacktrace:
  [1] getcode(f::Any, argtypes::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:247
  [2] trace!(t::Umlaut.Tracer, v_fargs::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:478
  [3] trace_call!(::Umlaut.Tracer{TapedContext}, ::Any, ::Vararg{Any})
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:297
  [4] trace_block!(t::Umlaut.Tracer, ir::Core.Compiler.IRCode, bi::Integer, prev_bi::Integer, sparams::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:334
  [5] trace!(t::Umlaut.Tracer, v_fargs::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:487
  [6] trace_call!(::Umlaut.Tracer{TapedContext}, ::Any, ::Vararg{Any})
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:297
  [7] trace_block!(t::Umlaut.Tracer, ir::Core.Compiler.IRCode, bi::Integer, prev_bi::Integer, sparams::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:334
  [8] trace!(t::Umlaut.Tracer, v_fargs::Any)
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:487
  [9] trace(::Any, ::Any, ::Vararg{Any}; ctx::Any, deprecated_kws::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
    @ Umlaut ~/ml/ad_playground/Taped.jl/dev/Umlaut/src/trace.jl:609
 [10] trace_shapes(f::Function, x::Vector{Any})
    @ Taped ~/ml/ad_playground/Taped.jl/src/tracing.jl:46
 [11] top-level scope
    @ REPL[30]:1

It tells me that arrayref is the problematic function, along with a particular set of argument types. It would be really helpful if this kind of error also reported where we currently are in the code that Umlaut is interpreting.

dfdx commented 1 year ago

Perhaps you want the (undocumented) function Umlaut.print_stack_trace(), which prints the stack trace so far. Or, if you want to dive right into the middle of tracing, you can try Umlaut.get_latest_tracer_state(), which returns the Tracer instance, the IR of the currently traced function, and a tuple of the argument Variables.

In any case, this is a tracing bug and shouldn't happen in normal circumstances. I can reproduce it using exactly the information from the error message:

julia> f = Base.arrayref
arrayref (built-in function)

julia> argtypes=(Expr, Vector{Any}, Int64)
(Expr, Vector{Any}, Int64)

julia> getcode(f, argtypes)
ERROR: AssertionError: Expected Pair{IRCode,...}, but got Pair{Method, DataType} instead for f=arrayref with argtypes=(Expr, Vector{Any}, Int64)
Stacktrace:
 [1] getcode(f::Any, argtypes::Any)
   @ Main ~/work/Umlaut.jl/src/trace.jl:247
 [2] top-level scope
   @ REPL[19]:1

Apparently, Base.code_ircode(f, argtypes; optimize_until="slot2reg") for Base.arrayref returns Pair{Method, ...} instead of Pair{IRCode, ...}. Maybe I can extract the IR from that Method, but I wonder why you even got to tracing the Base primitive, because functions in Base are usually recorded to the tape as is. Do you use a custom context that doesn't re-use BaseCtx? Can you share a bit more detail about your use case?
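
The distinction between the two kinds of result can be checked directly before tracing. Below is a minimal sketch, assuming Julia 1.9 or later (where Base.code_ircode is available). The helper name `has_traceable_ir` is hypothetical and not part of Umlaut, and the sketch omits the `optimize_until` keyword used in the call above.

```julia
# Hypothetical helper (not part of Umlaut): reports whether type inference
# produced IRCode for every method matching `f` with `argtypes`.
# Assumes Julia >= 1.9, where Base.code_ircode is available.
function has_traceable_ir(f, argtypes)
    results = Base.code_ircode(f, argtypes)
    # Each element is a Pair whose first field is either an IRCode
    # (IR was produced) or a Method (no IR could be produced).
    return !isempty(results) &&
           all(r -> first(r) isa Core.Compiler.IRCode, results)
end

# An ordinary Julia function, used here only for illustration.
double(x) = 2x

has_traceable_ir(double, (Int,))
```

Calling the same helper with `Base.arrayref` and the argument types from the error message would show which kind of result comes back on a given Julia version, without going through the tracer.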

willtebbutt commented 1 year ago

> Do you use a custom context that doesn't re-use BaseCtx? Can you share a bit more detail about your use case?

Exactly -- I'm intentionally playing around with what I treat as a primitive. Re my application -- I'm not too worried about solving this particular bug per se; I'm more concerned that it's generally tricky to know where in the code tracing has broken (if it breaks).

Thanks for pointing me towards Umlaut.print_stack_trace() -- I'll take a look at it

dfdx commented 1 year ago

The main issue, as I see it, is that you are trying to trace through a built-in function. Such functions don't behave like ordinary ones: they are implemented in C or in a low-level subset of Julia, and often don't have an IR representation that we can trace. Unless you want to dive deep into the Julia compiler and runtime, the only way to avoid such issues is to mark these functions as primitives, which is as simple as:

Umlaut.isprimitive(YourCtx(), Base.arrayref, args...) = true
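
For a fuller picture, here is a minimal sketch of a custom context that marks every built-in as a primitive rather than listing them one by one. The context name `BuiltinAwareCtx` and the function `square_plus_one` are hypothetical. Falling back to BaseCtx is an assumption made to keep the example small; a context that deliberately traces into Base would drop that fallback.

```julia
import Umlaut: isprimitive, BaseCtx, trace

# Hypothetical context for illustration; the name is not from the thread.
struct BuiltinAwareCtx end

# Treat every built-in function as a primitive, and otherwise defer to
# the rules of BaseCtx (an assumption made to keep this sketch small).
isprimitive(::BuiltinAwareCtx, f, args...) =
    f isa Core.Builtin || isprimitive(BaseCtx(), f, args...)

# An ordinary function to trace, used here only for illustration.
square_plus_one(x) = x * x + 1

# `trace` returns the computed value and the recorded tape.
val, tape = trace(square_plus_one, 3.0; ctx=BuiltinAwareCtx())
```

With this definition the tracer records calls to built-ins on the tape instead of asking for their IR.
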

willtebbutt commented 11 months ago

I'm closing this because Umlaut.print_stack_trace() solves my problem nicely.
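
One way to use this in practice is to wrap the call to trace so that Umlaut's own stack is printed whenever tracing fails. This is a minimal sketch: `trace_with_report` and `add_one` are hypothetical names, and it assumes that Umlaut.print_stack_trace() can still be called after the error has been thrown, as the thread suggests.

```julia
using Umlaut

# Hypothetical wrapper (not part of Umlaut): trace `f`, and if tracing
# throws, print Umlaut's stack of traced calls before rethrowing.
function trace_with_report(f, args...; ctx=Umlaut.BaseCtx())
    try
        return Umlaut.trace(f, args...; ctx=ctx)
    catch
        # Undocumented helper mentioned in this thread; its output
        # format is not specified and may change between versions.
        Umlaut.print_stack_trace()
        rethrow()
    end
end

# An ordinary function to trace, used here only for illustration.
add_one(x) = x + 1

val, tape = trace_with_report(add_one, 2.0)
```

The original error is rethrown unchanged, so the Julia stack trace shown at the top of this issue still appears, now preceded by the trace-level context.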