dfdx / Ghost.jl

The Code Tracer

`Vararg`s get clobbered? #26

Open darsnack opened 2 years ago

darsnack commented 2 years ago

I have the following code:

f(x, xs...) = max.(x, xs...)

where max is treated like a primitive. When I trace f(rand(2, 2), rand(2, 2)) with Ghost, I get

Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Matrix{Float64}
  inp %3::Matrix{Float64}
  %4 = tuple(max, %2)::Tuple{typeof(max), Matrix{Float64}}
  %5 = _apply_iterate(iterate, broadcasted, %4, %3)::Broadcasted{}
  %6 = materialize(%5)::Matrix{Float64}

Compare this to @code_lowered:

CodeInfo(
1 ─ %1 = Core.tuple(Main.max, x)
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, xs)
│   %3 = Base.materialize(%2)
└──      return %3
)

The call in @code_lowered makes sense: %1 and xs are both tuples, so both splat correctly. But in the tape, %3 is not going to splat correctly, because it refers to the input matrix itself instead of the tuple that holds the Vararg arguments, so _apply_iterate would iterate over the matrix's elements.
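To make the difference concrete, here is a minimal sketch in plain Julia (no Ghost involved) that calls Core._apply_iterate by hand, once with the Vararg tuple in the splat position, as in the lowered code, and once with the raw matrix there, as the printed tape suggests. The variable names are illustrative, and the second call is only my reading of what replaying the tape would do.

```julia
x = [1.0 2.0; 3.0 4.0]
y = [0.0 5.0; 6.0 1.0]

# Lowered code: xs is the tuple (y,), so splatting it yields a single matrix argument.
xs = (y,)
bc_ok = Core._apply_iterate(Base.iterate, Base.broadcasted, (max, x), xs)
result_ok = Base.materialize(bc_ok)    # same as max.(x, y)

# Tape as printed: the matrix itself sits in the splat position,
# so it is iterated element by element and each element becomes an argument.
bc_bad = Core._apply_iterate(Base.iterate, Base.broadcasted, (max, x), y)
result_bad = Base.materialize(bc_bad)  # same as max.(x, y...)
```

Both calls return a Matrix{Float64}, so the type annotation on the tape looks fine even though the values differ.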

darsnack commented 2 years ago

An even simpler MWE:

julia> f(x, xs...) = getindex.(x, xs...)
f (generic function with 1 method)

julia> Ghost.trace(f, X, 1)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Matrix{Float64}
  inp %3::Int64
  %4 = tuple(getindex, %2)::Tuple{typeof(getindex), Matrix{Float64}}
  %5 = _apply_iterate(iterate, broadcasted, %4, %3)::Broadcasted{}
  %6 = materialize(%5)::Matrix{Float64}

julia> @code_lowered f(X, 1)
CodeInfo(
1 ─ %1 = Core.tuple(Main.getindex, x)
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, xs)
│   %3 = Base.materialize(%2)
└──      return %3
)

darsnack commented 2 years ago

After looking into how to fix this, it seems the issue arises when the top-level function being traced has a Vararg signature. Since Ghost.trace(f, args...) already uses splatting, would we need to de-sugar args into the signature that f expects in order for the tracer to "see" the splatting in f?
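As a rough sketch of what that de-sugaring could look like, the helper below regroups a flat argument list into the shape a Vararg method sees inside its body. pack_varargs is a hypothetical name and not part of Ghost; it relies on the internal Method fields nargs and isva, which may change between Julia versions.

```julia
# Hypothetical helper, not part of Ghost: group trailing arguments the way
# a Vararg method receives them (fixed positionals, then one tuple).
function pack_varargs(f, args...)
    m = which(f, Tuple{map(typeof, args)...})  # method that would be dispatched to
    m.isva || return args                      # no Vararg slot: nothing to regroup
    nfixed = m.nargs - 2                       # nargs counts f itself and the Vararg slot
    return (args[1:nfixed]..., args[nfixed+1:end])
end

f(x, xs...) = max.(x, xs...)
g(x, y) = max.(x, y)

packed_f = pack_varargs(f, 1.0, 2.0, 3.0)  # (1.0, (2.0, 3.0))
packed_g = pack_varargs(g, 1.0, 2.0)       # (1.0, 2.0)
```

A tracer could, in principle, record the inner tuple as its own tape entry so that later splats refer to it instead of to the raw inputs; whether that fits Ghost's internals is an open question here.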

dfdx commented 2 years ago

Busy right now, but seems to be the same as dfdx/Yota.jl#84

dfdx commented 2 years ago

Yes, it's the same issue as the one I linked. Unfortunately, I don't have an immediate fix for it - IRTools adds some magic that I don't know how to mitigate. Recently I looked at Mixtape.jl and CodeInfoTools.jl as a more future-proof alternative to IRTools.jl, but refactoring would take quite a lot of time.

Usually, wrapping a top-level function into another without varargs solves the issue. Would it be sufficient for your current use case?
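A minimal sketch of that workaround, assuming the two-argument call from the original report: f_fixed is a hypothetical fixed-arity wrapper, and the Ghost call is left commented out so the snippet runs without the package installed.

```julia
f(x, xs...) = max.(x, xs...)

# Hypothetical wrapper: its own signature has no Vararg,
# so the top-level function handed to the tracer has fixed arity.
f_fixed(x, y) = f(x, y)

X, Y = rand(2, 2), rand(2, 2)
out = f_fixed(X, Y)   # same values as calling f(X, Y) directly

# With Ghost installed, trace the wrapper instead of f:
# using Ghost
# val, tape = Ghost.trace(f_fixed, X, Y)
```

The wrapper has to be written per arity, which is why it helps in a single use case but not for arbitrary user-defined functions.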

darsnack commented 2 years ago

In my case, I managed to avoid the Vararg. But my package that uses Ghost is supposed to support any user-defined function, so eventual support for Vararg would be nice.

dfdx commented 2 years ago

Agreed. My current plan is to take a look at various alternatives to IRTools, which is the bottleneck in this issue and is going to be deprecated anyway, and see what would be easier: updating to a new tech stack, or diving into the internals of IRTools and fixing it in the current version. Unfortunately, both options are quite complicated, so there is no estimated time for resolution yet :disappointed: