tpapp / TransformVariables.jl

Transformations to constrained variables from ℝⁿ.
Other
66 stars 14 forks source link

Replace mapreduce(f,op,...) with reduce(op,map(f,...)) for type stability #80

Closed chriselrod closed 3 years ago

chriselrod commented 3 years ago

`mapreduce` tends to be type-unstable with heterogeneous tuples, while `map` and `reduce` are not.

Note the `Union`s in the typed code before this change:

julia> trft
(CL_base = TransformVariables.ShiftedExp{true, Float64}(0.0), CL_logwt = TransformVariables.Identity(), v_base = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_1 = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_2 = TransformVariables.ShiftedExp{true, Float64}(0.0), σ_0 = TransformVariables.ShiftedExp{true, Float64}(0.0))

julia> @code_typed as(trft)
CodeInfo(
1 ── %1  = Base.sle_int(1, 1)::Bool
└───       goto #3 if not %1
2 ── %3  = Base.sle_int(1, 0)::Bool
└───       goto #4
3 ──       nothing::Nothing
4 ┄─ %6  = φ (#2 => %3, #3 => false)::Bool
└───       goto #6 if not %6
5 ──       invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───       unreachable
6 ──       goto #7
7 ──       goto #8
8 ──       goto #9
9 ──       goto #10
10 ─ %14 = Base.getfield(transformations, 1)::TransformVariables.ShiftedExp{true, Float64}
└─── %15 = Core.tuple(%14, 2)::Tuple{TransformVariables.ShiftedExp{true, Float64}, Int64}
11 ┄ %16 = φ (#10 => 1, #34 => %75)::Int64
│    %17 = φ (#10 => %15, #34 => %59)::Union{Nothing, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}}
└───       goto #35 if not true
12 ─ %19 = π (%17, Union{Tuple{TransformVariables.Identity, Int64}, Tuple{TransformVariables.ShiftedExp{true, Float64}, Int64}})
│    %20 = (isa)(%19, Tuple{TransformVariables.Identity, Int64})::Bool
└───       goto #14 if not %20
13 ─ %22 = π (%19, Tuple{TransformVariables.Identity, Int64})
│    %23 = Base.getfield(%22, 2, true)::Union{TransformVariables.Identity, Int64}
└───       goto #17
14 ─ %25 = (isa)(%19, Tuple{TransformVariables.ShiftedExp{true, Float64}, Int64})::Bool
└───       goto #16 if not %25
15 ─ %27 = π (%19, Tuple{TransformVariables.ShiftedExp{true, Float64}, Int64})
│    %28 = Base.getfield(%27, 2, true)::Union{Int64, TransformVariables.ShiftedExp{true, Float64}}
└───       goto #17
16 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└───       unreachable
17 ┄ %32 = φ (#13 => %23, #15 => %28)::Union{TransformVariables.Identity, Int64, TransformVariables.ShiftedExp{true, Float64}}
│    %33 = Base.iterate::typeof(iterate)
│    %34 = (isa)(%32, TransformVariables.Identity)::Bool
└───       goto #19 if not %34
18 ─ %36 = π (%32, TransformVariables.Identity)
│    %37 = invoke %33(_2::NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}, %36::TransformVariables.Identity)::Union{Nothing, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}}
└───       goto #27
19 ─ %39 = (isa)(%32, Int64)::Bool
└───       goto #24 if not %39
20 ─ %41 = π (%32, Int64)
│    %42 = Base.slt_int(6, %41)::Bool
└───       goto #22 if not %42
21 ─ %44 = Base.nothing::Nothing
└───       goto #23
22 ─ %46 = Base.getfield(transformations, %41)::Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}
│    %47 = Base.add_int(%41, 1)::Int64
│    %48 = Core.tuple(%46, %47)::Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}
└───       goto #23
23 ┄ %50 = φ (#21 => %44, #22 => %48)::Union{Nothing, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}}
└───       goto #27
24 ─ %52 = (isa)(%32, TransformVariables.ShiftedExp{true, Float64})::Bool
└───       goto #26 if not %52
25 ─ %54 = π (%32, TransformVariables.ShiftedExp{true, Float64})
│    %55 = invoke %33(_2::NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}, %54::TransformVariables.ShiftedExp{true, Float64})::Union{Nothing, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}}
└───       goto #27
26 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└───       unreachable
27 ┄ %59 = φ (#18 => %37, #23 => %50, #25 => %55)::Union{Nothing, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64}}
│    %60 = (%59 === Base.nothing)::Bool
└───       goto #29 if not %60
28 ─       goto #35
29 ─ %63 = π (%59, Tuple{Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}, Int64})
│    %64 = Base.getfield(%63, 1, true)::Union{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}
│    %65 = (isa)(%64, TransformVariables.Identity)::Bool
└───       goto #31 if not %65
30 ─ %67 = Base.add_int(%16, 1)::Int64
└───       goto #34
31 ─ %69 = (isa)(%64, TransformVariables.ShiftedExp{true, Float64})::Bool
└───       goto #33 if not %69
32 ─ %71 = Base.add_int(%16, 1)::Int64
└───       goto #34
33 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└───       unreachable
34 ┄ %75 = φ (#30 => %67, #32 => %71)::Int64
└───       goto #11
35 ┄       goto #36
36 ─       goto #37
37 ─       goto #38
38 ─       goto #39
39 ─       goto #40
40 ─       goto #41
41 ─       goto #42
42 ─       goto #43
43 ─ %85 = %new(TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}, transformations, %16)::TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}
└───       goto #44
44 ─       return %85
) => TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}

versus after this change:

julia> @code_typed as(trft)
CodeInfo(
1 ──       nothing::Nothing
2 ┄─ %2  = φ (#1 => 1, #8 => %16)::Int64
│    %3  = φ (#1 => 2, #8 => %11)::Int64
└───       goto #9 if not true
3 ── %5  = Base.slt_int(6, %3)::Bool
└───       goto #5 if not %5
4 ──       goto #6
5 ── %8  = Base.getfield($(QuoteNode((CL_base = 1, CL_logwt = 1, v_base = 1, ω_1 = 1, ω_2 = 1, σ_0 = 1))), %3)::Int64
│    %9  = Base.add_int(%3, 1)::Int64
└───       goto #6
6 ┄─ %11 = φ (#5 => %9)::Int64
│    %12 = φ (#4 => true, #5 => false)::Bool
│    %13 = φ (#5 => %8)::Int64
└───       goto #8 if not %12
7 ──       goto #9
8 ── %16 = Base.add_int(%2, %13)::Int64
└───       goto #2
9 ┄─       goto #10
10 ─       goto #11
11 ─       goto #12
12 ─       goto #13
13 ─       goto #14
14 ─       goto #15
15 ─       goto #16
16 ─       goto #17
17 ─       goto #18
18 ─       goto #19
19 ─       goto #20
20 ─       goto #21
21 ─ %30 = %new(TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}, transformations, %2)::TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}
└───       goto #22
22 ─       return %30
) => TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}

Benchmarks before and after:

julia> @btime as($trft) # before
  393.995 ns (19 allocations: 544 bytes)
TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}((CL_base = TransformVariables.ShiftedExp{true, Float64}(0.0), CL_logwt = TransformVariables.Identity(), v_base = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_1 = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_2 = TransformVariables.ShiftedExp{true, Float64}(0.0), σ_0 = TransformVariables.ShiftedExp{true, Float64}(0.0)), 6)

julia> @btime as($trft) # after
  1.649 ns (0 allocations: 0 bytes)
TransformVariables.TransformTuple{NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}}((CL_base = TransformVariables.ShiftedExp{true, Float64}(0.0), CL_logwt = TransformVariables.Identity(), v_base = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_1 = TransformVariables.ShiftedExp{true, Float64}(0.0), ω_2 = TransformVariables.ShiftedExp{true, Float64}(0.0), σ_0 = TransformVariables.ShiftedExp{true, Float64}(0.0)), 6)
tpapp commented 3 years ago

Thanks. Can you please add some @inferred tests so this does not happen again? Anywhere in the test file will do.

chriselrod commented 3 years ago

Unfortunately, I'm not sure how to make a test that fails on master to actually catch this issue. On master:

julia> @btime TransformVariables._sum_dimensions($trft)
  410.345 ns (19 allocations: 544 bytes)
6

julia> @inferred TransformVariables._sum_dimensions(trft)
6

julia> @code_warntype TransformVariables._sum_dimensions(trft)
Variables
  #self#::Core.Const(TransformVariables._sum_dimensions)
  transformations::NamedTuple{(:CL_base, :CL_logwt, :v_base, :ω_1, :ω_2, :σ_0), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}, TransformVariables.ShiftedExp{true, Float64}}}

Body::Int64
1 ─ %1 = (:init,)::Core.Const((:init,))
│   %2 = Core.apply_type(Core.NamedTuple, %1)::Core.Const(NamedTuple{(:init,), T} where T<:Tuple)
│   %3 = Core.tuple(0)::Core.Const((0,))
│   %4 = (%2)(%3)::Core.Const((init = 0,))
│   %5 = Core.kwfunc(TransformVariables.mapreduce)::Core.Const(Base.var"#mapreduce##kw"())
│   %6 = (%5)(%4, TransformVariables.mapreduce, TransformVariables.dimension, TransformVariables.:+, transformations)::Int64
└──      return %6

The return type is stable and @code_warntype also looks fine, but internally mapreduce isn't and the problems show up in Cthulhu.@descend, @code_typed, and benchmarks.

I could do something like calling `occursin` on the output of `code_typed`.

So maybe this is best thought of as a workaround for an issue in Julia Base.

tpapp commented 3 years ago

Thanks for investigating this. Can you please add a comment then into the source with a short explanation, and maybe a reference to this issue, so that we will not modify it inadvertently?

Also, is there an issue in Julia about this?

chriselrod commented 3 years ago

I added the same comment to both instances. I'm not aware of a Julia issue.

I also added a test using @test iszero(@allocated ...) which passes on this PR but fails on master.

tpapp commented 3 years ago

Thanks for the nice PR, merging, and will tag a release too.