JuliaAI / MLJLinearModels.jl

Generalized Linear Regression Models (penalized regressions, robust regressions, ...)
MIT License

first pass at Gramian training for OLS #146

Closed: adienes closed this pull request 9 months ago

adienes commented 1 year ago

Ref: https://github.com/JuliaAI/MLJLinearModels.jl/issues/145

This is a very basic initial implementation. I have only checked the simplest possible case (Lasso with a single target, no intercept, and so on), but the following works as expected:

# X, y some data
lasso = LassoRegression(0.01, fit_intercept=false)
fit_gram(lasso, X'X, X'y; n=100)
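For context, the Gram formulation only needs the p×p matrix X'X and the p-vector X'y rather than the full n×p data, which is why it pays off when n ≫ p. A minimal sketch in plain Julia (independent of the package) of recovering the unpenalized least-squares solution from the Gram pair alone:

```julia
using LinearAlgebra, Random

Random.seed!(0)
n, p = 1_000, 5
X = randn(n, p)
y = X * [1.0, -2.0, 0.0, 0.5, 3.0] + 0.01 * randn(n)

XX = X' * X          # p × p Gram matrix
Xy = X' * y          # p-vector

# Normal equations solved from the Gram pair alone; adding λ*I on the
# left would give ridge, and an L1 prox step would give lasso.
θ_gram   = XX \ Xy
θ_direct = X \ y     # reference solution from the full data
```

The two solutions agree up to numerical conditioning, since both solve the same normal equations.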

Please suggest any cleanups needed to make this a robust addition to the package. I know it was done a little quickly, but I wasn't sure what to focus on cleaning up.

Still to do:

  1. decide what the right signature is for accessing training on XX, Xy
  2. either make the inputs stricter (with loud NotImplementedErrors) or make the implementation more generic; right now the functions accept more than they can actually handle
  3. add test cases
tlienart commented 1 year ago

OK, so I won't have time for a full in-depth review just yet, but here are some thoughts already:

  1. good that it's working and doing what you expected
  2. there's no need for an additional fit_gram function; that should be taken care of by the solver. The logic is that the solver indicates which specific fit method to use, which is what you want to do in your case
  3. there's no need for a separate gramproxgrad file; it should be possible to just extend the existing one (code duplication is to be avoided)
  4. ignore LinearMap for now; you're not actually using it, and it would needlessly add an external dependency
  5. there need to be tests; you can write those by comparing against a non-X'X solver, or against a scikit-learn solver
adienes commented 1 year ago

@tlienart thank you for the comments and apologies for the delay

I made the call signature like this:

θ_gram = fit(enr; data=(; XX, Xy, n), solver=FISTA(gram=true))

I did not want to pun the meaning of the X and y arguments, so I added a data kwarg. The reasons are:

  1. when using this from MLJ, it would be surprising if cross-validation / fold splitting were accidentally applied to XX and Xy
  2. it might be desirable to support gram=true even when passing in X and y, with the kernels then precomputed by the library. This would also allow fit_intercept, which I left unimplemented here
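To make the data=(; XX, Xy, n) idea concrete, here is a minimal ISTA sketch that operates purely on the Gram pair, the same quantities the proposed data kwarg carries. This is an illustrative standalone implementation, not the package's internal solver, and the scaling convention for λ is illustrative only:

```julia
using LinearAlgebra, Random

# Soft-thresholding operator, the prox of the L1 penalty.
soft(x, t) = sign(x) * max(abs(x) - t, zero(x))

# Proximal gradient (ISTA) for the lasso using only XX = X'X and Xy = X'y.
function lasso_ista_gram(XX, Xy, λ; iters=2_000)
    θ = zeros(size(XX, 1))
    L = opnorm(XX)                 # Lipschitz constant -> step size 1/L
    for _ in 1:iters
        g = XX * θ - Xy            # gradient of 0.5‖Xθ - y‖² via the Gram pair
        θ = soft.(θ - g ./ L, λ / L)
    end
    θ
end

Random.seed!(0)
X = randn(200, 5)
y = X * [2.0, 0.0, -1.0, 0.0, 0.5]
θ = lasso_ista_gram(X' * X, X' * y, 0.0)   # λ = 0 reduces to OLS
```

Note that nothing in the loop touches X or y directly, which is what makes fold splitting on XX and Xy meaningless and motivates keeping them out of the positional X, y slots.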
adienes commented 1 year ago

For X of size (100000, 200), performance is nearly 200x better in both runtime and memory usage, including the cost of forming X'X and X'y:

julia> @benchmark lin.fit(enet, X, y; data, solver=FISTA(gram=true, max_iter=5000))
BenchmarkTools.Trial: 257 samples with 1 evaluation.
 Range (min … max):  19.109 ms …  21.066 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.522 ms               ┊ GC (median):    1.19%
 Time  (mean ± σ):   19.524 ms ± 254.764 μs  ┊ GC (mean ± σ):  0.77% ± 0.60%

        ▂▂ ▂▃ ▁  ▂   ▆▅ ▆ ▅█▁▆▃▃▂▃▂                             
  ▄▆▄▃▄▄██▇██▃█▄▆█▇▅███▅█▅██████████▅▄▇▃▅▁▃▁▃▁▃▁▁▃▃▃▁▅▅▄▃▅▁▁▁▃ ▄
  19.1 ms         Histogram: frequency by time         20.1 ms <

 Memory estimate: 27.34 MiB, allocs estimate: 44058.

julia> @benchmark lin.fit(enet, X, y; data, solver=FISTA(gram=false, max_iter=5000))
BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 7.386 s (1.65% GC) to evaluate,
 with a memory estimate of 4.70 GiB, over 50418 allocations.
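The speedup is consistent with the arithmetic: each iteration needs the gradient X'(Xθ - y), which equals XX*θ - Xy, so with the Gram pair precomputed an iteration costs O(p²) instead of O(np). A small sketch verifying the identity (dimensions here are illustrative, not the benchmark's):

```julia
using LinearAlgebra, Random

Random.seed!(1)
n, p = 10_000, 20
X, y = randn(n, p), randn(n)
θ = randn(p)

XX, Xy = X' * X, X' * y
g_full = X' * (X * θ - y)   # O(np) work per iteration
g_gram = XX * θ - Xy        # O(p²) work per iteration, same gradient
```

With n/p = 500 in the benchmark above, a per-iteration saving of roughly that factor is what one would expect.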
codecov-commenter commented 1 year ago

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison is base (0b48318) 96.11% compared to head (79e3628) 96.13%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             dev     #146    +/-  ##
=======================================
+ Coverage   96.11%   96.13%   +0.01%
=======================================
  Files          22       22
  Lines         876      905     +29
=======================================
+ Hits          842      870     +28
- Misses         34       35      +1
```

| Files | Coverage Δ |
|---|---|
| src/fit/default.jl | 88.46% (+2.74%) |
| src/fit/proxgrad.jl | 95.74% (+0.39%) |
| src/fit/solvers.jl | 100.00% (ø) |
| src/glr/d_l2loss.jl | 100.00% (ø) |
| src/glr/utils.jl | 92.30% (+1.39%) |
| src/mlj/classifiers.jl | 100.00% (ø) |
| src/utils.jl | 96.96% (-0.74%) |


adienes commented 9 months ago

was wondering if we could dust this off. is the API ok? what are the main concerns?

tlienart commented 9 months ago

will wait for CI to pass then merge & do a minor release with it, thanks again