darsnack opened this issue 2 years ago
An even simpler MWE (here `X` is a `Matrix{Float64}`):

```julia
julia> f(x, xs...) = getindex.(x, xs...)
f (generic function with 1 method)

julia> Ghost.trace(f, X, 1)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Matrix{Float64}
  inp %3::Int64
  %4 = tuple(getindex, %2)::Tuple{typeof(getindex), Matrix{Float64}}
  %5 = _apply_iterate(iterate, broadcasted, %4, %3)::Broadcasted{}
  %6 = materialize(%5)::Matrix{Float64}

julia> @code_lowered f(X, 1)
CodeInfo(
1 ─ %1 = Core.tuple(Main.getindex, x)
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, xs)
│   %3 = Base.materialize(%2)
└──      return %3
)
```
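You can replay the lowered body of `f(X, 1)` by hand to see what it does with `xs`. This is a minimal sketch that does not involve Ghost. It assumes a Julia version in which splatting lowers to `Core._apply_iterate`, as in the `@code_lowered` output above. The variable names are illustrative. The key point is that `xs` arrives as the tuple `(1,)`, and `_apply_iterate` splats every collection argument it receives after the target function.

```julia
# Minimal sketch: replay the lowered body of f(X, 1) by hand (no Ghost involved).
f(x, xs...) = getindex.(x, xs...)

X = [1.0 2.0; 3.0 4.0]

xs = (1,)                      # varargs arrive packed into a tuple
t  = Core.tuple(getindex, X)   # %1 in the lowered code
# %2: splat both `t` and `xs` into Base.broadcasted,
# i.e. build broadcasted(getindex, X, 1)
bc = Core._apply_iterate(Base.iterate, Base.broadcasted, t, xs)
r  = Base.materialize(bc)      # %3

println(r == f(X, 1))          # the hand-replayed body matches calling f directly
```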
After looking into how to fix this, it seems to be an issue when the top-level function being called has a `Vararg` signature. Since `Ghost.trace(f, args...)` already uses splatting, would we need to correctly de-sugar `args` into the signature that `f` expects, so that the tracer can "see" the splatting in `f`?
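Here is one way to picture that de-sugaring. Look up the method that `f` would dispatch to. If it is variadic, regroup the trailing arguments into a tuple, which is the shape the lowered body expects for `xs`. This is a minimal sketch: `desugar_varargs` is a hypothetical helper, not part of Ghost's API. It relies on the `isva` and `nargs` fields of Julia's `Method` objects and handles only a single trailing `Vararg` with no keyword arguments.

```julia
# Hypothetical helper, not part of Ghost.jl: regroup a flat argument list into
# the shape the called method's body sees (trailing varargs packed in a tuple).
function desugar_varargs(f, args...)
    m = which(f, Tuple{map(typeof, args)...})   # method that f(args...) dispatches to
    m.isva || return args                       # not variadic: nothing to regroup
    # m.nargs counts the function itself (#self#) and the vararg slot,
    # so the number of fixed positional arguments is nargs - 2.
    nfixed = Int(m.nargs) - 2
    return (args[1:nfixed]..., args[nfixed+1:end])
end

f(x, xs...) = getindex.(x, xs...)
g(x, y) = x + y

X = [1.0 2.0; 3.0 4.0]
println(desugar_varargs(f, X, 1))   # X is kept; the trailing 1 is packed into a tuple
println(desugar_varargs(g, 1, 2))   # non-variadic method: arguments unchanged
```

A tracer could bind its inputs to these regrouped values so that `xs` refers to a tuple rather than to the raw trailing argument. Whether this fits Ghost's IRTools-based internals is a separate question.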
Busy right now, but this seems to be the same issue as dfdx/Yota.jl#84.
Yes, it's the same issue as the one I linked. Unfortunately, I don't have an immediate fix for it: IRTools adds some magic that I don't know how to mitigate. Recently I looked at Mixtape.jl and CodeInfoTools.jl as more future-proof alternatives to IRTools.jl, but refactoring would take quite a lot of time.
Usually, wrapping the top-level function in another function without varargs solves the issue. Would that be sufficient for your current use case?
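Here is a minimal sketch of that workaround for the MWE above. `f_fixed` is a hypothetical name. It forwards to `f` with a fixed number of arguments, so the variadic call happens inside the traced code rather than at the top level. The `Ghost.trace` call is left as a comment so the snippet runs without Ghost installed.

```julia
f(x, xs...) = getindex.(x, xs...)

# Hypothetical fixed-arity wrapper: its own signature has no Vararg,
# so the splatting into `f` happens inside the function being traced.
f_fixed(x, i) = f(x, i)

X = [1.0 2.0; 3.0 4.0]
@assert f_fixed(X, 1) == f(X, 1)   # same result as calling the variadic function

# With Ghost.jl loaded, trace the wrapper instead of `f`:
# tape = Ghost.trace(f_fixed, X, 1)[2]
```

The drawback is that the wrapper fixes the number of trailing arguments, so each call shape needs its own wrapper.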
In my case, I managed to avoid the `Vararg`. But my package that uses Ghost is supposed to support any user-defined function, so eventual support for `Vararg` would be nice.
Agreed. My current plan is to look at various alternatives to IRTools, which is the bottleneck in this issue and is to be deprecated anyway, and see which would be easier: updating to a new tech stack, or diving into the internals of IRTools and fixing it in the current version. Unfortunately, both options are quite complicated, so there is no estimated time for a resolution yet :disappointed:
I have the following code:

where `max` is treated like a primitive. When I trace `f(rand(2, 2), rand(2, 2))` with Ghost, I get:

Compare this to `@code_lowered`:

The call in `@code_lowered` makes sense: `%1` and `xs` are both going to splat correctly. But in the tape, `%3` is not going to splat correctly, because it refers to the input matrix instead of the intermediate `Vararg` tuple.
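You can reproduce the difference between the two calls without Ghost. This is a minimal sketch under some assumptions. It uses `+` as a stand-in broadcast function, since the code from the report is not reproduced here. It also uses fixed 2×2 matrices instead of `rand(2, 2)` and calls `Core._apply_iterate` directly. Splatting the tuple `(Y,)`, as the lowered code does with `xs`, passes `Y` as a single argument. Splatting `Y` itself, as the tape does with `%3`, passes its four elements as separate arguments.

```julia
X = [1.0 2.0; 3.0 4.0]
Y = [10.0 20.0; 30.0 40.0]
t = (+, X)   # stand-in for the tuple built in %1; `+` is illustrative only

# Lowered code: xs == (Y,), so splatting yields one extra argument, Y itself.
ok = Base.materialize(Core._apply_iterate(Base.iterate, Base.broadcasted, t, (Y,)))

# Tape: %3 is Y itself, so splatting iterates its elements in column-major order,
# i.e. broadcasted(+, X, 10.0, 30.0, 20.0, 40.0).
bad = Base.materialize(Core._apply_iterate(Base.iterate, Base.broadcasted, t, Y))

# An Int iterates as itself exactly once, so splatting 1 and (1,) build the same call.
same_int = Base.materialize(Core._apply_iterate(Base.iterate, Base.broadcasted, (getindex, X), 1))
same_tup = Base.materialize(Core._apply_iterate(Base.iterate, Base.broadcasted, (getindex, X), (1,)))

println(ok)              # elementwise X .+ Y
println(bad)             # X plus the sum of all entries of Y
println(same_int == same_tup)
```

This is consistent with the simpler `getindex` MWE. There, the wrong binding of `%3` happens to produce the same call for an integer index, so the mismatch only changes the result when the trailing argument iterates over several elements, as a matrix does.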