JuliaAI / MLJBase.jl

Core functionality for the MLJ machine learning framework

Address some predict/transform type instabilities #969

Closed · ablaom closed this 5 months ago

ablaom commented 6 months ago

(edited) This PR addresses a type instability for operations (predict, transform, etc.) acting on machines, as identified in #959 (although this PR does not resolve the particular issue reported there).
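
For reference, this is the kind of check that exposes such an instability. The snippet below is a sketch only; ConstantRegressor and make_regression are just convenient stand-ins, not code from this PR:

```
using MLJBase, MLJModels, Test

X, y = make_regression(20, 3)              # small synthetic regression problem
mach = machine(ConstantRegressor(), X, y)
fit!(mach, verbosity=0)

# Throws if the return type of `predict` cannot be inferred from the
# argument types alone, i.e. if the call is type-unstable:
Test.@inferred predict(mach, X)

# For a detailed view of where inference gives up:
# @code_warntype predict(mach, X)
```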

I admit it is not clear to me that the performance gains here are likely to significantly benefit many use cases. But having done the work to identify these instabilities, I don't see harm in addressing them.

The type instability is not difficult to address in the case of machines attached to ordinary models, by annotating a currently abstract type in the Machine struct. However, in the special case of a machine attached to a symbolic model (such machines appear exclusively in learning networks), the type instability remains, and looks difficult to remove.
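
To illustrate the general idea with a toy example (a sketch only, not the actual Machine code):

```
# With an abstractly typed field, the compiler cannot infer the type of
# `m.model`, so calls that use it are type-unstable:
struct LooseMach
    model::Any
end

# Giving the field a concrete type, here via a type parameter, lets the
# compiler specialise on it:
struct TightMach{M}
    model::M
end

operate(m) = sum(m.model)   # stand-in for predict/transform

loose = LooseMach([1.0, 2.0])
tight = TightMach([1.0, 2.0])

# `@code_warntype operate(loose)` reports `Any`, while `operate(tight)`
# infers `Float64`.
```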

Benchmarks

For 69 regression models, we compared the "high level" `predict(::Machine, ...)` method with the "low level" `predict(::Model, ...)` method (edit: plus `reformat(::Model, ...)`) implemented by third-party model providers. The benchmark code is given below:

```
using MLJTestInterface, MLJModels, MLJBase
using Tables
using Random
using BenchmarkTools
using Statistics
import DataFrames
import MLJModelInterface as MMI

const MODELS = models() do m
    !(m.package_name in ["MLJText"]) &&
        AbstractVector{Continuous} <: m.target_scitype &&
        m.is_supervised
end

# This is a way to load all needed model code:
MLJTestInterface.test(MODELS, mod=@__MODULE__, level=1, throw=true)

rng = Random.MersenneTwister(0)
Xmat = randn(rng, 30, 3)
X = Tables.table(Xmat)
y = @. cos(Xmat[:, 1] * 2.1 - 0.9) * Xmat[:, 2] - Xmat[:, 3]

function predict_low(model, fitresult, X)
    Xraw = MMI.reformat(model, X)
    MMI.predict(model, fitresult, Xraw...)
end

stats = []
for m in MODELS
    print("\rBenchmarking $(m.name) $(m.package_name).")
    model = eval(:(@load $(m.name) pkg=$(m.package_name) verbosity=0))()
    mach = machine(model, X, y)
    fit!(mach, verbosity=0)
    fitresult = mach.fitresult
    b_high = @benchmark predict($mach, $X)
    b_low = @benchmark predict_low($model, $fitresult, $X)
    slow_down = median(b_high.times)/median(b_low.times)
    bloat = b_high.allocs/b_low.allocs
    push!(stats, (; model=m.name, pkg=m.package_name, slow_down, bloat))
    print(" Done. ")
end

@show length(MODELS) # length(MODELS) = 69

stats = DataFrames.DataFrame([stats...])

filter(stats) do row
    row.slow_down > 1.75 || row.bloat > 2.0
end
```

In the tables below, only models with slow_down > 1.75 or bloat > 2.0 are reported. Here slow_down is the ratio of median execution times (high level over low level) and bloat is the corresponding ratio of allocations.

Before this PR:


```
#  Row │ model                           pkg                           slow_down  bloat
#      │ String                          String                        Float64    Float64
# ─────┼──────────────────────────────────────────────────────────────────────────────────
#    1 │ ConstantRegressor               MLJModels                       9.67007    4.0
#    2 │ DeterministicConstantRegressor  MLJModels                      14.0756     4.0
#    3 │ ElasticNetRegressor             MLJLinearModels                 4.41319    2.0
#    4 │ HuberRegressor                  MLJLinearModels                 3.93931    2.0
#    5 │ LADRegressor                    MLJLinearModels                 3.97205    2.0
#    6 │ LassoRegressor                  MLJLinearModels                 4.20191    2.0
#    7 │ LinearRegressor                 MLJLinearModels                 3.95137    2.0
#    8 │ LinearRegressor                 MultivariateStats               4.92059    2.5
#    9 │ PLSRegressor                    PartialLeastSquaresRegressor    2.11492    1.375
#   10 │ QuantileRegressor               MLJLinearModels                 4.15291    2.0
#   11 │ RidgeRegressor                  MLJLinearModels                 3.99158    2.0
#   12 │ RidgeRegressor                  MultivariateStats               5.03699    2.5
#   13 │ RobustRegressor                 MLJLinearModels                 3.95992    2.0
```

After this PR:

```
#  Row │ model              pkg        slow_down  bloat
#      │ String             String     Float64    Float64
# ─────┼──────────────────────────────────────────────────
#    1 │ ConstantRegressor  MLJModels    1.61401      3.0
```

Note that machines serialised using #master cannot be deserialised after this PR, but I don't consider that this warrants a breaking release.
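
For context, the affected round trip looks something like this (a sketch only, assuming the usual MLJBase.save / machine(file) serialisation API; the model and file name are illustrative):

```
using MLJBase, MLJModels

X, y = make_regression(20, 3)
mach = fit!(machine(ConstantRegressor(), X, y), verbosity=0)

MLJBase.save("mach.jls", mach)   # saved with an older (#master) MLJBase ...
mach2 = machine("mach.jls")      # ... restoring under this PR's Machine struct fails
```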

To do:

ablaom commented 5 months ago

Doesn't look like there's a significant difference in load time.

Before this PR:

```
julia> @time_imports import MLJBase
    413.2 ms  MLJBase 23.21% compilation time
```

After this PR:

```
julia> @time_imports import MLJBase
    437.3 ms  MLJBase 22.21% compilation time
```