JuliaStats / StatsModels.jl

Specifying, fitting, and evaluating statistical models in Julia

`fit` is very slow for new formulas #220

Open evanfields opened 3 years ago

evanfields commented 3 years ago

Calling StatsModels.fit with a not-yet-seen formula seems to trigger pretty slow compilation, even if a structurally equivalent formula with different names has been seen before. Calling fit with a formula that has been seen before is very fast.

The reproducing example below uses GLM and DataFrames and closely mimics how I stumbled upon this issue in the wild. I'm not familiar with the StatsModels/GLM internals, but if this example isn't minimal enough I can try to drill down.

julia> using StatsModels, DataFrames, GLM

julia> df = DataFrame(rand(100,100));

julia> function f(df, a, b, c)
           reg_form = Term(a) ~ Term(b) + Term(c)
           return r2(lm(reg_form, df))
       end
f (generic function with 1 method)

julia> # precompile f with arg types dataframe, symbol, symbol, symbol

julia> f(df, :x1, :x2, :x3)
0.0033857274581728936

julia> # That was really slow, around 7 seconds. Future calls with the exact same symbols are fast

julia> # call f with the same arg types, but new symbol values

julia> @elapsed f(df, :x1, :x2, :x4)
0.1441765

julia> # call f again with already-seen symbol values, and it's super fast

julia> @elapsed f(df, :x1, :x2, :x4)
0.00016

julia> versioninfo()
Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: AMD Ryzen 5 1600X Six-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, znver1)
Environment:
  JULIAPATH = C:\Users\ejfie\AppData\Local\Programs\Julia-1.6.0\bin
  JULIA_NUM_THREADS = 4

(@v1.6) pkg> st StatsModels
      Status `C:\Users\ejfie\.julia\environments\v1.6\Project.toml`
  [3eaba693] StatsModels v0.6.21

julia> using Profile

julia> @profile begin
           for i in 5:20
               symb = Symbol("x$i")
               f(df, :x1, :x2, symb)
           end
       end

julia> Profile.print(mincount=50,maxdepth=30)
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
   ╎1254 @Base\client.jl:485; _start()
   ╎ 1254 @Base\client.jl:302; exec_options(opts::Base.JLOptions)
   ╎  1254 @Base\client.jl:372; run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
   ╎   1254 @Base\essentials.jl:706; invokelatest
   ╎    1254 @Base\essentials.jl:708; #invokelatest#2
   ╎     1254 @Base\client.jl:387; (::Base.var"#874#876"{Bool, Bool, Bool})(REPL::Module)
   ╎    ╎ 1254 C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:305; run_repl(repl::REPL.AbstractREPL, consumer::Any)
   ╎    ╎  1254 C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:317; run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
   ╎    ╎   1254 C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:185; start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
   ╎    ╎    1254 C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:200; repl_backend_loop(backend::REPL.REPLBackend)
   ╎    ╎     1254 C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:139; eval_user_input(ast::Any, backend::REPL.REPLBackend)
   ╎    ╎    ╎ 1254 @Base\boot.jl:360; eval
   ╎    ╎    ╎  1254 ...uildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\Profile\src\Profile.jl:28; top-level scope
   ╎    ╎    ╎   1254 REPL[9]:4; macro expansion
   ╎    ╎    ╎    1254 REPL[3]:3; f(df::DataFrame, a::Symbol, b::Symbol, c::Symbol)
   ╎    ╎    ╎     1254 @GLM\src\lm.jl:156; lm
   ╎    ╎    ╎    ╎ 1254 @GLM\src\lm.jl:156; lm
   ╎    ╎    ╎    ╎  1254 @GLM\src\lm.jl:156; #lm#2
   ╎    ╎    ╎    ╎   1254 @StatsModels\src\statsmodel.jl:82; fit
   ╎    ╎    ╎    ╎    518  @StatsModels\src\statsmodel.jl:85; fit(::Type{LinearModel}, f::FormulaTerm{Term, Tuple{Term, Term}}, data::DataFrame, args::Nothing; contrasts::Dict{Symbol, An...
   ╎    ╎    ╎    ╎     518  @StatsModels\src\modelframe.jl:74; (::Core.var"#Type##kw")(::NamedTuple{(:model, :contrasts), Tuple{UnionAll, Dict{Symbol, Any}}}, ::Type{ModelFrame}, f::Formu...
   ╎    ╎    ╎    ╎    ╎ 166  @StatsModels\src\modelframe.jl:74; ModelFrame(f::FormulaTerm{Term, Tuple{Term, Term}}, data::NamedTuple{(:x1, :x2, :x3, :x4, :x5, :x6, :x7, :x8, :x9, :x10, :x...
101╎    ╎    ╎    ╎    ╎  166  @StatsModels\src\modelframe.jl:69; missing_omit(data::NamedTuple{(:x1, :x2, :x3, :x4, :x5, :x6, :x7, :x8, :x9, :x10, :x11, :x12, :x13, :x14, :x15, :x16, :x17,...
   ╎    ╎    ╎    ╎    ╎   64   @Base\compiler\typeinfer.jl:921; typeinf_ext_toplevel(mi::Core.MethodInstance, world::UInt64)
   ╎    ╎    ╎    ╎    ╎    64   @Base\compiler\typeinfer.jl:925; typeinf_ext_toplevel(interp::Core.Compiler.NativeInterpreter, linfo::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎     58   @Base\compiler\typeinfer.jl:892; typeinf_ext(interp::Core.Compiler.NativeInterpreter, mi::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎    ╎ 58   @Base\compiler\typeinfer.jl:209; typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎  53   @Base\compiler\typeinfer.jl:214; _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎   53   @Base\compiler\abstractinterpretation.jl:1520; typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
221╎    ╎    ╎    ╎    ╎ 273  @StatsModels\src\modelframe.jl:76; ModelFrame(f::FormulaTerm{Term, Tuple{Term, Term}}, data::NamedTuple{(:x1, :x2, :x3, :x4, :x5, :x6, :x7, :x8, :x9, :x10, :x...
   ╎    ╎    ╎    ╎    ╎  52   @Base\compiler\typeinfer.jl:921; typeinf_ext_toplevel(mi::Core.MethodInstance, world::UInt64)
   ╎    ╎    ╎    ╎    ╎   52   @Base\compiler\typeinfer.jl:925; typeinf_ext_toplevel(interp::Core.Compiler.NativeInterpreter, linfo::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎    52   @Base\compiler\typeinfer.jl:892; typeinf_ext(interp::Core.Compiler.NativeInterpreter, mi::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎     52   @Base\compiler\typeinfer.jl:209; typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎ 51   @Base\compiler\typeinfer.jl:214; _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎  51   @Base\compiler\abstractinterpretation.jl:1520; typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎   51   @Base\compiler\abstractinterpretation.jl:1462; typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎    51   @Base\compiler\abstractinterpretation.jl:1167; abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Any}, sv::Core.Compiler.Infer...
   ╎    ╎    ╎    ╎    ╎    ╎     51   @Base\compiler\abstractinterpretation.jl:1040; abstract_call(interp::Core.Compiler.NativeInterpreter, fargs::Vector{Any}, argtypes::Vector{Any}, sv::Core.Compiler.I...
 70╎    ╎    ╎    ╎    ╎ 79   @StatsModels\src\modelframe.jl:79; ModelFrame(f::FormulaTerm{Term, Tuple{Term, Term}}, data::NamedTuple{(:x1, :x2, :x3, :x4, :x5, :x6, :x7, :x8, :x9, :x10, :x...
 85╎    ╎    ╎    ╎    625  @StatsModels\src\statsmodel.jl:86; fit(::Type{LinearModel}, f::FormulaTerm{Term, Tuple{Term, Term}}, data::DataFrame, args::Nothing; contrasts::Dict{Symbol, An...
   ╎    ╎    ╎    ╎     522  @StatsModels\src\modelframe.jl:222; ModelMatrix(mf::ModelFrame{NamedTuple{(:x1, :x2, :x5), Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}, LinearMod...
   ╎    ╎    ╎    ╎    ╎ 522  @StatsModels\src\modelframe.jl:218; ModelMatrix{Matrix{Float64}}(mf::ModelFrame{NamedTuple{(:x1, :x2, :x5), Tuple{Vector{Float64}, Vector{Float64}, Vector{Floa...
   ╎    ╎    ╎    ╎    ╎  522  @StatsModels\src\modelframe.jl:147; modelmatrix
432╎    ╎    ╎    ╎    ╎   522  @StatsModels\src\modelframe.jl:147; modelmatrix(mf::ModelFrame{NamedTuple{(:x1, :x2, :x5), Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}, LinearM...
   ╎    ╎    ╎    ╎    ╎    90   @Base\compiler\typeinfer.jl:921; typeinf_ext_toplevel(mi::Core.MethodInstance, world::UInt64)
   ╎    ╎    ╎    ╎    ╎     90   @Base\compiler\typeinfer.jl:925; typeinf_ext_toplevel(interp::Core.Compiler.NativeInterpreter, linfo::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎    ╎ 90   @Base\compiler\typeinfer.jl:892; typeinf_ext(interp::Core.Compiler.NativeInterpreter, mi::Core.MethodInstance)
   ╎    ╎    ╎    ╎    ╎    ╎  90   @Base\compiler\typeinfer.jl:209; typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎   89   @Base\compiler\typeinfer.jl:214; _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎    89   @Base\compiler\abstractinterpretation.jl:1520; typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
   ╎    ╎    ╎    ╎    ╎    ╎     89   @Base\compiler\abstractinterpretation.jl:1462; typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
 62╎    ╎    ╎    ╎    65   @StatsModels\src\statsmodel.jl:87; fit(::Type{LinearModel}, f::FormulaTerm{Term, Tuple{Term, Term}}, data::DataFrame, args::Nothing; contrasts::Dict{Symbol, An...
Total snapshots: 1269
pdeffebach commented 3 years ago

I can reproduce. Here are my timings

julia> using GLM, DataFrames

julia> df = DataFrame(rand(100,100));

julia> function f(df, a, b, c)
           reg_form = Term(a) ~ Term(b) + Term(c)
           return r2(lm(reg_form, df))
       end
f (generic function with 1 method)

julia> @time f(df, :x1, :x2, :x3)
  7.844389 seconds (22.20 M allocations: 1.386 GiB, 5.80% gc time)
0.002045333822805362

julia> @time f(df, :x1, :x2, :x3)
  0.000191 seconds (235 allocations: 38.891 KiB)
0.002045333822805362

julia> @time f(df, :x1, :x2, :x4)
  0.167007 seconds (163.14 k allocations: 9.977 MiB, 98.56% compilation time)
0.0015202494005057687

julia> @time f(df, :x1, :x2, :x4)
  0.000167 seconds (235 allocations: 38.891 KiB)
0.0015202494005057687

julia> @time f(df, :x1, :x2, :x5)
  0.192700 seconds (163.12 k allocations: 9.974 MiB, 8.88% gc time, 98.73% compilation time)
0.041061833597036856
kleinschmidt commented 3 years ago

I suspect this has to do with the modelcols or ModelMatrix methods specializing on the data NamedTuple (where the names are type parameters). Currently, we implement generic Tables.jl support by coercing the input data to a NamedTuple of vectors before doing anything with it. I wonder whether there's some kind of alternative strategy that would a) avoid the conversion and b) not take such a big compilation hit. Something like Tables.columns.

I think one roadblock for just using getcolumn or columns everywhere is that we're also relying on the NamedTuple type in order to dispatch (and, I suspect, avoid method ambiguities) AND to special-case handling a single row vs. an entire table (e.g. for interaction terms). But we could get around that with some kind of internal wrapper types (or maybe Tables.jl provides something for this?)

kleinschmidt commented 3 years ago

Actually, I think using the Tables.Columns and Tables.Row wrappers would work just fine. They support everything that the NamedTuple does and are IIUC lazy, and also provide dispatch targets.
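
A minimal sketch of the generic access those wrappers provide, assuming Tables.jl ≥ 1.6 and data like the df above (illustrative only, not the StatsModels implementation):

using Tables, DataFrames

df = DataFrame(rand(100, 100), :auto)  # `:auto` is needed on recent DataFrames versions

cols = Tables.Columns(df)              # column-oriented wrapper around the table
y = Tables.getcolumn(cols, :x1)        # same column access a NamedTuple of vectors gives
colnames = Tables.columnnames(cols)    # :x1, :x2, ..., :x100

row = first(Tables.rows(df))           # row-oriented access for single-row code paths
Tables.getcolumn(row, :x2)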

kleinschmidt commented 3 years ago

I've played around with this a bit more and I can't reproduce it using just apply_schema and modelcols. I suspected that, because modelcols takes the NamedTuple of the data as one argument, it would specialize and trigger re-compilation, but that doesn't seem to be the case. Here's what I tried:

julia> function g(df, a, b, c)
           reg_form = Term(a) ~ Term(b) + Term(c)
           return apply_schema(reg_form, schema(reg_form, df), RegressionModel)
       end
g (generic function with 1 method)

julia> @time g(df, :x1, :x2, :x5);
  0.122910 seconds (392.04 k allocations: 23.309 MiB, 99.87% compilation time)

julia> @time g(df, :x1, :x2, :x5);
  0.000068 seconds (128 allocations: 16.672 KiB)

julia> @time g(df, :x1, :x2, :x6);
  0.000065 seconds (128 allocations: 16.672 KiB)

julia> h(df, args...) = modelcols(g(df, args...), df)
h (generic function with 1 method)

julia> @time h(df, :x1, :x2, :x5);
  0.190115 seconds (586.57 k allocations: 35.274 MiB, 99.93% compilation time)

julia> @time h(df, :x1, :x2, :x5);
  0.000084 seconds (159 allocations: 30.922 KiB)

julia> @time h(df, :x1, :x2, :x6);
  0.000088 seconds (159 allocations: 30.922 KiB)

Even the first run with a new formula is fast after any formula with that structure has been compiled once.

So I suspect it has something to do with the ModelMatrix or ModelFrame wrappers...

matthieugomez commented 3 years ago

Using a type that does not need to be specialized over and over again would be awesome! Or maybe use @nospecialize everywhere.

kleinschmidt commented 3 years ago

Yeah it's strange...I'd figured that any specialization would hit those paths too but it doesn't seem like it. I'll have to dig into where the specialization is taking place (or, someone will ;)

Unfortunately Tables.Columns has a type parameter for the wrapped table type so I don't think it'll solve the problem in all cases, although it may help with sources that don't have structural information like column names/types in the type.

matthieugomez commented 3 years ago

> Unfortunately Tables.Columns has a type parameter for the wrapped table type so I don't think it'll solve the problem in all cases, although it may help with sources that don't have structural information like column names/types in the type.

Yes, but that's actually perfect, no? If I pass a DataFrame then it won't specialize, whereas if I pass a ColumnTable it will specialize — that's to be expected.
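
A sketch of that distinction, assuming Tables.jl ≥ 1.6 (the exact printed types may differ by version):

using DataFrames, Tables

df = DataFrame(rand(10, 3), :auto)

# Wrapping a DataFrame: the column names are not part of the wrapper's type, so
# methods dispatching on it compile once, whichever columns a formula uses.
typeof(Tables.Columns(df))                       # roughly Tables.Columns{DataFrame}

# Wrapping a ColumnTable (NamedTuple of vectors): names and element types are baked
# into the type parameter, so each new column set means fresh specialization.
typeof(Tables.Columns(Tables.columntable(df)))   # Tables.Columns{NamedTuple{(:x1, :x2, :x3), ...}}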

matthieugomez commented 3 years ago

Btw I think the slowdown comes from missing_omit, which creates a new NamedTuple type depending on the variables in the formula.
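
A minimal, hypothetical sketch of that mechanism (not StatsModels' actual code). Selecting only the formula's columns yields a distinct NamedTuple type per column combination, so everything downstream compiles again:

data = (x1 = rand(100), x2 = rand(100), x3 = rand(100), x4 = rand(100))

# roughly what missing_omit does (the missing-row dropping is omitted here):
# keep only the columns the formula mentions
select_cols(data, names) = NamedTuple{names}(data)

t1 = select_cols(data, (:x1, :x2, :x3))
t2 = select_cols(data, (:x1, :x2, :x4))

typeof(t1) == typeof(t2)   # false: each column combination is a new type, so methods
                           # taking the subset specialize (and compile) once per formula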

kleinschmidt commented 3 years ago

Ahhh that's interesting then, and would explain why I'm not hitting it in the tests above. Maybe the specialization was a red herring then. I wonder if there's a generic-tables-compatible way of doing missing_omit...

pdeffebach commented 3 years ago

Maybe you could use TableOperations.filter.

There is also skipmissings to identify all the observations that are missing.
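
A rough sketch of the TableOperations.filter idea (illustrative, not what StatsModels currently does): drop incomplete rows while staying generic over the table type.

using DataFrames, Tables, TableOperations

df = DataFrame(x1 = [1.0, missing, 3.0], x2 = [4.0, 5.0, missing], x3 = 1.0:3.0)

used = (:x1, :x2)   # columns the formula actually uses

# lazily keep rows that are non-missing in the used columns, without building
# a per-formula NamedTuple of the data
complete = TableOperations.filter(row -> all(!ismissing(Tables.getcolumn(row, c)) for c in used), df)

Tables.columntable(complete)   # materialize if/when needed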

matthieugomez commented 3 years ago

I think it's still about specialization — it's just that everything after missing_omit is respecialized to the new dataset. Yes, I think the way forward would be to write a missing_omit that takes a Tables.Columns and creates a Tables.Columns if possible.
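
One possible sketch of the first half of that idea (illustrative names, not the StatsModels API): compute the rows to keep from only the columns a formula uses, going through the generic column accessors so no per-formula NamedTuple type is built.

using Tables

function nonmissing_mask(cols::Tables.Columns, used)
    keep = trues(length(Tables.getcolumn(cols, first(used))))
    for c in used
        keep .&= map(!ismissing, Tables.getcolumn(cols, c))
    end
    return keep
end

# usage sketch: view each needed column through the same wrapper with the mask applied
# keep = nonmissing_mask(Tables.Columns(df), (:x1, :x2))
# x1 = view(Tables.getcolumn(Tables.Columns(df), :x1), keep)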

matthieugomez commented 3 years ago

cf https://github.com/JuliaData/TableOperations.jl/issues/7

pdeffebach commented 3 years ago

fwiw I think it's likely that @nospecialize will help in this scenario.

julia> using Statistics  # for `mean` below

julia> namedtuples = map(1:50) do _
           names = rand('a':'z', 10);
           v = [Symbol(n) => rand(10) for n in names]
           (;v...)
       end;

julia> function foo(t)
           nms = collect(keys(t))
           means = map(mean, collect(values(t)))
           return nms .=> means
       end;

julia> @time foo(namedtuples[1]);
  0.093913 seconds (214.92 k allocations: 12.716 MiB, 99.96% compilation time)

julia> @time foo(namedtuples[1]);
  0.000012 seconds (4 allocations: 720 bytes)

julia> @time foo(namedtuples[2]);
  0.061658 seconds (160.72 k allocations: 9.307 MiB, 99.95% compilation time)

julia> @time foo(namedtuples[2]);
  0.000013 seconds (4 allocations: 704 bytes)

julia> function bar(@nospecialize t)
           nms = collect(keys(t))
           means = map(mean, collect(values(t)))
           return nms .=> means
       end;

julia> @time bar(namedtuples[11]);
  0.034917 seconds (29.15 k allocations: 1.985 MiB, 99.64% compilation time)

julia> @time bar(namedtuples[11]);
  0.000042 seconds (8 allocations: 896 bytes)

julia> @time bar(namedtuples[12]);
  0.009303 seconds (3.31 k allocations: 212.256 KiB, 98.54% compilation time)

julia> @time bar(namedtuples[12]);
  0.000046 seconds (8 allocations: 800 bytes)
kleinschmidt commented 3 years ago

Since Tables 1.6 (https://github.com/JuliaData/Tables.jl/releases/tag/v1.6.0), Tables.Columns will actually reliably return a Columns object, so we could use that for dispatch. I started playing around with that in #247, but there are some design issues to work out (and I ran into the fact that Columns is a lie, which is now fixed)