JuliaAI / MLJLinearModels.jl

Generalized Linear Regression Models (penalized regressions, robust regressions, ...)
MIT License
81 stars, 13 forks

API to construct MLJ-proper model from standalone model #159

Open: adienes opened this issue 6 months ago

adienes commented 6 months ago
julia> enet = ElasticNetRegression()
GeneralizedLinearRegression{L2Loss, CompositePenalty}
  loss: L2Loss L2Loss()
  penalty: CompositePenalty
  fit_intercept: Bool true
  penalize_intercept: Bool false
  scale_penalty_with_samples: Bool true

julia> ElasticNetRegressor(enet)
ERROR: MethodError: no method matching ElasticNetRegressor(::GeneralizedLinearRegression{L2Loss, CompositePenalty})
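
The following is a minimal, self-contained sketch of the kind of method being asked for. `NativeElasticNet` and `ElasticNetModel` are hypothetical stand-ins for ElasticNetRegression and ElasticNetRegressor, not the package's actual types. The real ElasticNetRegression stores a CompositePenalty rather than `lambda` and `gamma` directly (see the printout above), so a real conversion would first have to recover those two values from the penalty. The stand-in skips that step.

```julia
# HYPOTHETICAL stand-in for the standalone object (ElasticNetRegression).
# Unlike the real type, it stores lambda and gamma directly.
struct NativeElasticNet
    lambda::Float64
    gamma::Float64
    fit_intercept::Bool
    penalize_intercept::Bool
    scale_penalty_with_samples::Bool
end

# HYPOTHETICAL stand-in for the MLJ model (ElasticNetRegressor), with the
# same fields and defaults as the printout of ElasticNetRegressor() above.
Base.@kwdef mutable struct ElasticNetModel
    lambda::Float64 = 1.0
    gamma::Float64 = 0.0
    fit_intercept::Bool = true
    penalize_intercept::Bool = false
    scale_penalty_with_samples::Bool = true
    solver::Any = nothing
end

# The requested API: construct the MLJ-style model from a standalone object.
# `solver` is a keyword because the standalone object has no such field.
ElasticNetModel(e::NativeElasticNet; solver = nothing) = ElasticNetModel(
    lambda = e.lambda,
    gamma = e.gamma,
    fit_intercept = e.fit_intercept,
    penalize_intercept = e.penalize_intercept,
    scale_penalty_with_samples = e.scale_penalty_with_samples,
    solver = solver,
)

# Usage: convert a standalone object into the MLJ-style model.
native = NativeElasticNet(1.0, 0.1, true, false, true)
model = ElasticNetModel(native)   # lambda = 1.0, gamma = 0.1, solver = nothing
```

The solver is passed separately because, in the printouts above, it is a field of the MLJ model but not of the standalone object.
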
ablaom commented 6 months ago

@tlienart may wish to differ, but my understanding is that ElasticNetRegression is a private constructor, i.e. it has no associated public API.

ElasticNetRegressor, on the other hand, is public. It constructs an object storing hyperparameters (what MLJ calls a "model"), and you use it like this:

using MLJBase # to get pretty printing

# default model:
julia> ElasticNetRegressor()
ElasticNetRegressor(
  lambda = 1.0, 
  gamma = 0.0, 
  fit_intercept = true, 
  penalize_intercept = false, 
  scale_penalty_with_samples = true, 
  solver = nothing)

# with a different `gamma` value:
julia> ElasticNetRegressor(gamma=0.1)
ElasticNetRegressor(
  lambda = 1.0, 
  gamma = 0.1, 
  fit_intercept = true, 
  penalize_intercept = false, 
  scale_penalty_with_samples = true, 
  solver = nothing)

Like other models, you can bind it to data in a machine, which you then fit! to learn parameters that are stored in the machine, and so forth. See this example.
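
Here is a minimal sketch of that workflow, assuming MLJBase and MLJLinearModels are both installed. The toy data and variable names are made up for illustration.

```julia
using MLJBase, MLJLinearModels

# Toy data (made up). MLJ expects the features as a table; a NamedTuple
# of equal-length vectors is a valid Tables.jl table.
n = 100
X = (x1 = randn(n), x2 = randn(n))
y = 2 .* X.x1 .- X.x2 .+ 1.0 .+ 0.1 .* randn(n)

model = ElasticNetRegressor(gamma = 0.1)  # stores hyperparameters only
mach = machine(model, X, y)               # bind the model to data
fit!(mach, verbosity = 0)                 # learned parameters are stored in `mach`
fp = fitted_params(mach)                  # inspect what was learned
yhat = predict(mach, X)                   # predictions on the training features
```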

I presume that an ElasticNetRegression object gets created under the hood in fit!, but as I say, it is not exposed to the user, as far as I am aware.

adienes commented 6 months ago

It is documented as public API here: https://juliaai.github.io/MLJLinearModels.jl/stable/api/#MLJLinearModels.ElasticNetRegression

Furthermore, the fit and predict methods in MLJLinearModels.jl only work on ElasticNetRegression, not ElasticNetRegressor, and these methods are surely public API.
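
For reference, here is a minimal sketch of the native API, assuming MLJLinearModels is installed. It assumes the positional `ElasticNetRegression(λ, γ)` form. It also assumes that `fit` appends the intercept as the last entry of the returned vector when `fit_intercept = true`. The data is synthetic, and predictions are computed by hand from the learned parameters rather than through a package predict function.

```julia
using MLJLinearModels

# Synthetic data (made up): 100 samples, 3 features, known coefficients.
X = randn(100, 3)
y = X * [1.0, -2.0, 0.5] .+ 3.0

# Standalone ("native") object; positional arguments assumed to be λ and γ.
enet = ElasticNetRegression(0.1, 0.05)

# `fit` returns a parameter vector; with fit_intercept = true (the default
# shown in the printout above) the intercept is assumed to be the last entry.
θ = fit(enet, X, y)
coefs, intercept = θ[1:end-1], θ[end]

# Predictions computed by hand from the learned parameters.
yhat = X * coefs .+ intercept
```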

Unless this entire package should be considered internal?

ablaom commented 6 months ago

I stand corrected. I had forgotten there is also a "native" API. In that case I hope @tlienart can answer your question. I am only familiar with the MLJ interface.