JuliaLang / julia

Vararg specialization isn't firm enough #34365

Open MasonProtter opened 4 years ago

MasonProtter commented 4 years ago

In Cassette.overdub passes, varargs often incur a performance penalty.

For instance,

using Cassette
using BenchmarkTools  # provides @btime, used in the timings below

Cassette.@context TraceCtx

mutable struct Trace
    current::Vector{Any}
    stack::Vector{Any}
    Trace() = new(Any[], Any[])
end

function enter!(t::Trace, args...)
    pair = args => Any[]
    push!(t.current, pair)
    push!(t.stack, t.current)
    t.current = pair.second
    return nothing
end

function exit!(t::Trace)
    t.current = pop!(t.stack)
    return nothing
end

Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...)
Cassette.posthook(ctx::TraceCtx, args...) = exit!(ctx.metadata)

trace = Trace()
x, y, z = rand(3)
f(x, y, z) = x*y + y*z

julia> @btime Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))
  3.315 μs (41 allocations: 1.48 KiB)
0.2360528466104866

Here, the vararg splatting in enter!(t::Trace, args...) accounts for a large fraction of the above allocations and runtime. Switching to args::Vararg{Any, N} where N doesn't seem to help:

julia> function enter!(t::Trace, args::Vararg{Any, N}) where {N}
           pair = args => Any[]
           push!(t.current, pair)
           push!(t.stack, t.current)
           t.current = pair.second
           return nothing
       end
enter! (generic function with 1 method)

julia> @btime Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))
  3.883 μs (61 allocations: 1.81 KiB)
0.1532882013156685

(Note that enter!(t::Trace, args::Vararg{<:T, N}) where {T, N} won't work, because it forces all the varargs to share one type, which must be a subtype of T.)
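To illustrate the restriction mentioned in the note above, here is a minimal sketch. The function names `same_type` and `any_type` are hypothetical, and the sketch uses `Vararg{T, N}` rather than `Vararg{<:T, N}`, on the assumption that both tie every argument to a single type in the same way.

```julia
# Hypothetical names, for illustration only.
# A single type parameter T shared by every vararg element must bind to one type.
same_type(args::Vararg{T, N}) where {T, N} = (T, N)

# With Any as the element type, each argument may have a different type.
any_type(args::Vararg{Any, N}) where {N} = map(typeof, args)

same_type(1, 2, 3)         # fine: every argument is an Int
any_type(1, 2.0, "three")  # fine: mixed argument types are accepted

# Mixed types cannot bind the single parameter T, so no method matches.
mixed_ok = try
    same_type(1, 2.0)
    true
catch err
    err isa MethodError ? false : rethrow()
end
```

This is why the per-argument type parameters (`T1`, `T2`, ...) in the macro-generated methods are needed if each argument is to keep its own type.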

However, a macro I wrote to work around this does reduce the overhead:

julia> using SpecializeVarargs

julia> @specialize_vararg 5 function enter!(t::Trace, args...)
           pair = args => Any[]
           push!(t.current, pair)
           push!(t.stack, t.current)
           t.current = pair.second
           return nothing
       end
enter! (generic function with 6 methods)

julia> @btime Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))
  1.595 μs (27 allocations: 1.17 KiB)
0.1532882013156685

The macro above generated five copies of the method, one per argument count, with the last one also accepting any additional arguments:

enter!(t::Trace, x1::T1) where {T1}
enter!(t::Trace, x1::T1, x2::T2) where {T1, T2}
enter!(t::Trace, x1::T1, x2::T2, x3::T3) where {T1, T2, T3}
enter!(t::Trace, x1::T1, x2::T2, x3::T3, x4::T4) where {T1, T2, T3, T4}
enter!(t::Trace, x1::T1, x2::T2, x3::T3, x4::T4, x5::T5, args...) where {T1, T2, T3, T4, T5}

This is a bit clunky, and ideally it wouldn't be necessary if args::Vararg{Any, N} or something similar were stricter about specialization.
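For readers who want to see the pattern without the macro, the fixed-arity signatures listed above could be written by hand roughly as follows. This is a sketch under assumptions, not the actual expansion of @specialize_vararg: the names `record!`, `_record!`, and `entries` are hypothetical, and only three arities are shown to keep it short.

```julia
# Hypothetical helper holding the shared body; it receives the arguments as a tuple.
function _record!(entries::Vector{Any}, args::Tuple)
    push!(entries, args => Any[])
    return nothing
end

# One method per argument count, each with its own type parameter per argument,
# so that dispatch sees a concrete signature for every arity.
record!(entries::Vector{Any}) = _record!(entries, ())
record!(entries::Vector{Any}, x1::T1) where {T1} = _record!(entries, (x1,))
record!(entries::Vector{Any}, x1::T1, x2::T2) where {T1, T2} =
    _record!(entries, (x1, x2))
# The last method keeps a trailing splat so that longer calls still dispatch.
record!(entries::Vector{Any}, x1::T1, x2::T2, x3::T3, rest...) where {T1, T2, T3} =
    _record!(entries, (x1, x2, x3, rest...))

entries = Any[]
record!(entries, 1)
record!(entries, 1, 2.0)
record!(entries, 1, 2.0, "three", :four)  # handled by the trailing-splat method
```

The cost of this approach is the one noted above: the number of methods grows with the maximum arity you want specialized, which is what the macro automates.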

JeffBezanson commented 4 years ago

Seems like a use case for @specialize. (Note that while @specialize exists, it doesn't yet have this functionality.)

timholy commented 4 years ago

xref #33978