Closed adienes closed 9 months ago
ok so I won't have the time for a full in-depth review just yet, but here are some thoughts already:
- `fit_gram` function: that should be taken care of by the solver, the logic being that the solver indicates which specific fit method to use, which is what you want to do in your case
- `gramproxgrad` file: it should be possible to just extend the existing one (code duplication is something to be avoided)

@tlienart thank you for the comments and apologies for the delay
I made the call signature like this

θ_gram = fit(enr; data=(; XX, Xy, n), solver=FISTA(gram=true))

I did not want to pun the arguments of `X` and `y`, so I added a kwarg `data`. The reasons are:
- `XX` and `Xy`
- `gram=true` even when passing in `X` and `y`, and then the kernels will be precomputed by the library. This will also allow for `fit_intercept`, which I left unimplemented here

For `X` of size `(100000, 200)`, the performance is nearly 200x better for both runtime and memory usage, including the cost to create `X'X` and `X'y`:
julia> @benchmark lin.fit(enet, X, y; data, solver=FISTA(gram=true, max_iter=5000))
BenchmarkTools.Trial: 257 samples with 1 evaluation.
Range (min … max): 19.109 ms … 21.066 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 19.522 ms ┊ GC (median): 1.19%
Time (mean ± σ): 19.524 ms ± 254.764 μs ┊ GC (mean ± σ): 0.77% ± 0.60%
▂▂ ▂▃ ▁ ▂ ▆▅ ▆ ▅█▁▆▃▃▂▃▂
▄▆▄▃▄▄██▇██▃█▄▆█▇▅███▅█▅██████████▅▄▇▃▅▁▃▁▃▁▃▁▁▃▃▃▁▅▅▄▃▅▁▁▁▃ ▄
19.1 ms Histogram: frequency by time 20.1 ms <
Memory estimate: 27.34 MiB, allocs estimate: 44058.
julia> @benchmark lin.fit(enet, X, y; data, solver=FISTA(gram=false, max_iter=5000))
BenchmarkTools.Trial: 1 sample with 1 evaluation.
Single result which took 7.386 s (1.65% GC) to evaluate,
with a memory estimate of 4.70 GiB, over 50418 allocations.
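
The gap in the benchmark above comes from the fact that the gradient of the smooth least-squares term, `X'Xθ - X'y`, depends on the data only through `X'X` and `X'y`, so each iteration costs a `p × p` product instead of two passes over all `n` rows. A minimal sketch of that idea follows. It is not the code in this PR: `lasso_ista_gram` and `soft_threshold` are hypothetical names, it uses plain ISTA rather than FISTA, it covers only the lasso penalty with no intercept, and it ignores the `n` carried in the `data` tuple (no `1/n` scaling of the objective is assumed).

```julia
using LinearAlgebra

# Soft-thresholding: proximal operator of τ * ‖·‖₁, applied elementwise.
soft_threshold(z, τ) = sign(z) * max(abs(z) - τ, zero(z))

# Hypothetical helper (not part of MLJLinearModels): plain ISTA for
#     minimize  ‖Xθ - y‖² / 2 + λ‖θ‖₁
# using only the precomputed XX = X'X and Xy = X'y.
function lasso_ista_gram(XX::AbstractMatrix, Xy::AbstractVector, λ::Real;
                         max_iter::Int=5_000, tol::Real=1e-10)
    θ = zeros(eltype(Xy), length(Xy))
    # Lipschitz constant of the smooth gradient: largest eigenvalue of X'X.
    L = eigmax(Symmetric(XX))
    η = 1 / L
    for _ in 1:max_iter
        # Gradient of ‖Xθ - y‖²/2 is X'Xθ - X'y: a p×p product, independent of n.
        g = XX * θ - Xy
        θ_new = soft_threshold.(θ .- η .* g, η * λ)
        done = norm(θ_new - θ) <= tol * max(1, norm(θ))
        θ = θ_new
        done && break
    end
    return θ
end
```

With `X` and `y` in memory, a call would look like `lasso_ista_gram(X'X, X'y, 0.1)`; the two products are computed once up front and the loop never touches `X` again.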
Attention: 1 lines in your changes are missing coverage. Please review.
Comparison is base (0b48318) 96.11% compared to head (79e3628) 96.13%.
was wondering if we could dust this off. is the API ok? what are the main concerns?
will wait for CI to pass then merge & do a minor release with it, thanks again
Ref: https://github.com/JuliaAI/MLJLinearModels.jl/issues/145
This is a very basic initial implementation. I only checked the simplest possible case, which is `Lasso` with a single target, no intercept, etc., but the following will work as expected:

Please suggest any cleanups that need to be made to make this a robust addition to the package. I know it is a little quickly done, but I wasn't sure what to focus on cleaning up.

Need to still
- `NotImplementedErrors`) or make implementation more generic; right now the functions will accept more than they can actually handle