JuliaAI / MLJ.jl

A Julia machine learning framework
https://juliaai.github.io/MLJ.jl/

Transformers that need to see target (eg, recursive feature elimination) #874

Closed ablaom closed 4 months ago

ablaom commented 2 years ago

A number of feature-reduction strategies only make sense in the context of a supervised learning task, because they must consult a target variable when trained. For example, one might want to drop features that correlate poorly with the target. In fact, all but the first of sklearn's feature selectors are of this kind.

At the level of the basic model API, a transformer (or any other model) can specify any number of arguments to be used in training. So there is nothing wrong with a transformer with a fit method like

MLJModelInterface.fit(model::MyTransformer, verbosity, X, y) = ...

There is now a trait defined in MLJModelInterface to explicitly articulate the acceptable fit signatures (up to scitype). For any model type that subtypes Unsupervised, this falls back to a single-argument signature whose scitype must coincide with input_scitype(model). So, for transformers that need the target in training, you would override the trait with a declaration such as:

MLJModelInterface.fit_data_scitype(M::Type{<:MyTransformer}) =
    Tuple{input_scitype(M), target_scitype(M)}

and be sure to declare a target_scitype, just as you would for a supervised model. That should do it.
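
For concreteness, here is a minimal sketch of what such a transformer might look like. CorrelationSelector, its threshold hyperparameter, and the correlation-based selection rule are all hypothetical, purely to show where the declarations go:

import MLJModelInterface as MMI
using Statistics

# Hypothetical transformer: drop features whose absolute correlation with the
# target falls below `threshold`.
mutable struct CorrelationSelector <: MMI.Unsupervised
    threshold::Float64
end

function MMI.fit(model::CorrelationSelector, verbosity, X, y)
    Xm = MMI.matrix(X)
    features = collect(MMI.schema(X).names)
    keep = [abs(cor(Xm[:, j], y)) >= model.threshold for j in 1:size(Xm, 2)]
    fitresult = features[keep]                     # names of the surviving features
    cache, report = nothing, (; kept_features=fitresult)
    return fitresult, cache, report
end

MMI.transform(::CorrelationSelector, fitresult, Xnew) =
    MMI.selectcols(Xnew, fitresult)

# the declarations discussed above:
MMI.input_scitype(::Type{<:CorrelationSelector}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:CorrelationSelector}) = AbstractVector{MMI.Continuous}
MMI.fit_data_scitype(M::Type{<:CorrelationSelector}) =
    Tuple{MMI.input_scitype(M), MMI.target_scitype(M)}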

It may be that some argument checks for machines have to be tweaked in MLJBase (edit now done) but this should be very easy and essentially non-breaking.

Most happy to provide support to anyone wishing to implement such transformers.

ablaom commented 2 years ago

cc @pazzo83

pazzo83 commented 2 years ago

Would it make sense to create a new model type for this? It's like a supervised transformer.

ablaom commented 2 years ago

I think the general consensus is a move away from types to traits. I'm afraid the discussions are a little scattered. See, for example https://github.com/alan-turing-institute/MLJ.jl/issues/852#issuecomment-987655527.

We could add a trait for "supervised transformers", but perhaps this is unnecessary, as fit_data_scitype (together with subtyping <:Unsupervised, which ultimately might be encoded as is_supervised(model) = false) essentially captures the whole behaviour, no?

pazzo83 commented 2 years ago

So would my transformer that relies on a target subtype Unsupervised? I have tried that, overriding the definition of fit_data_scitype (as above), but I'm not sure what type to pass in where you have MyTransformer. If I create an abstract type that subtypes Unsupervised, and then have my transformer subtype that, then I still run into the check on machines:

ArgumentError: `Unsupervised` models should have one training argument, except `Static` models, which have none. Use  `machine(model, X; ...)` (usual case) or `machine(model; ...)` (static case). 

I guess I'm not familiar enough with the inner workings of the API yet to know whether that check needs to be modified, but could you expand a bit on what you mean by using traits to allow for this functionality?

pazzo83 commented 2 years ago

I was able to partially get this working with the following patches (after I declared my own abstract type that my transformers subtype):

abstract type TargetTransformer <: MMI.Unsupervised end

MLJModelInterface.fit_data_scitype(M::Type{<:TargetTransformer}) =
    Tuple{input_scitype(M), target_scitype(M)}

MLJBase.check(model::TargetTransformer, args...; full=false) = MLJBase.check_supervised(model, full, args...)

MLJBase.warn_scitype(model::TargetTransformer, X, y) =
    "The scitype of `y`, in `machine(model, X, y, ...)` "*
    "is incompatible with "*
    "`model=$model`:\nscitype(y) = "*
    "$(MLJBase.elscitype(y))\ntarget_scitype(model) "*
    "= $(MLJBase.target_scitype(model))."
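
With those patches in place, binding a machine directly works along these lines (MySelector is a hypothetical concrete TargetTransformer, and the data below is synthetic):

using MLJBase

X, y = make_regression(100, 5)     # synthetic feature table and continuous target
selector = MySelector()            # hypothetical concrete TargetTransformer
mach = machine(selector, X, y)     # two training arguments now pass the argument check
fit!(mach)
Xsmall = transform(mach, X)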

However, when put into a pipeline, it no longer works. It seems that, because the model is still Unsupervised, the target is not getting passed through (see here: https://github.com/JuliaAI/MLJBase.jl/blob/dev/src/composition/models/pipelines.jl#L72).

This might be kind of hacky and I think you are suggesting something a bit different?

ablaom commented 2 years ago

@pazzo83 Thanks for looking at this.

This isn't far from what I was imagining. Only, rather than introduce a new abstract type, I'd overload fit_data_scitype case-by-case. That is, each new model implementation for a concrete "supervised" transformer type MyTransformer includes the declaration

MLJModelInterface.fit_data_scitype(M::Type{<:MyTransformer}) =
    Tuple{input_scitype(M), target_scitype(M)}

Then, to fix the type checking, modify the existing MLJBase.check(::Unsupervised, ...) method so as to catch transformers that legitimately require length(mach.args) == 2, which you can detect because fit_data_scitype(model) is a 2-tuple.

Ditto MLJBase.warn_scitype(::Unsupervised, ...).
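
A rough sketch of the kind of relaxation being suggested; the helper name and exact structure are illustrative only, not the actual MLJBase code:

# Illustrative sketch: infer the expected number of training arguments from the
# declared fit signature, rather than from the abstract supertype alone.
function expected_nargs(model)
    F = fit_data_scitype(model)
    F isa DataType && F <: Tuple ? length(F.parameters) : 1
end

function check_nargs(model::Unsupervised, args...)
    n = expected_nargs(model)
    length(args) == n || throw(ArgumentError(
        "`$(typeof(model))` expects $n training argument(s) but got $(length(args))."))
    return nothing
end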

This is all a little tricky: we want flexibility of design, but we also want to catch users' unintentional mistakes with informative errors. Warnings are better than errors here, but even warnings should be thrown only as necessary.

And we'd need to test the changes with a dummy "supervised" transformer in tests.

However, when put into a pipeline, it no longer works

No. That will require a bit more work. Nevertheless, I'm pretty sure one could include these transformers in custom composite models (exported learning networks) without issues. So they would be useful even without the pipeline enhancement.
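
For example, a wiring along these lines should already be possible in a learning network, because each machine is bound to its source nodes explicitly (MySelector and SomeRegressor are placeholders for a concrete "supervised" transformer and any supervised model):

Xs = source(X)                            # source nodes for features and target
ys = source(y)

mach1 = machine(MySelector(), Xs, ys)     # the transformer is trained on X *and* y
W = transform(mach1, Xs)

mach2 = machine(SomeRegressor(), W, ys)   # downstream supervised model
yhat = predict(mach2, W)

fit!(yhat)                                # trains the whole network
yhat()                                    # predictions on the training data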

I'd support a PR that fixes the checks without worrying about pipelines just yet. Would also be great to have an actual "supervised" transformer implementation to try this out on. Have you already started on something?

pazzo83 commented 2 years ago

Thanks for the feedback! I can definitely put together a PR for this - I have some local code I've been working on so I can incorporate your feedback and go from there.

pazzo83 commented 2 years ago

I've been looking at this over the last couple of days based on the feedback here: https://github.com/JuliaAI/MLJBase.jl/pull/705

Would it work if we removed all the various check_* methods and simply kept this method:

function check(model::Model, args...; full=false)
    nowarns = true

    F = fit_data_scitype(model)

    # skip the check when the declared fit signature is unknown:
    (F >: Unknown || F >: Tuple{Unknown} || F >: NTuple{<:Any,Unknown}) &&
        return true

    S = Tuple{elscitype.(args)...}
    if !(S <: F)
        @warn warn_generic_scitype_mismatch(S, F)
        nowarns = false
    end

    return nowarns
end

I got it working if I rewrote the line S = Tuple{elscitype.(args)...} to just S = Tuple{scitype.(args)...}. Then it just checks whether the model you are trying to use matches the scitype signature it was declared to have. I think that's what we ultimately want, right?
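
The distinction matters because fit_data_scitype is declared in terms of container scitypes, not element scitypes. A quick illustration of the difference, assuming the usual ScientificTypes behaviour:

using ScientificTypes

y = rand(10)
scitype(y)      # AbstractVector{Continuous}, the form fit_data_scitype expects
elscitype(y)    # Continuous, the element scitype, too fine-grained for this check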

ablaom commented 2 years ago

Yes! That is what I think we should do. And it indeed looks like you found a bug with the elscitype <-> scitype business - good catch.

I would expand the return value of warn_generic_scitype_mismatch(S, F) along the lines previously suggested and copied below:

"The number and/or types of data arguments do not match what the specified model supports. Commonly, but non exclusively, supervised models are constructed using the syntax machine(model, X, y) or machine(model, X, y, w) while most other models with machine(model, X). Here X are features, y a target, and w sample or class weights. In general, data in machine(model, data...) must satisfy scitype(data) <:MLJ.fit_data_scitype(model)unless the right-hand side isUnknown`. "

Thanks for getting back to this!

ablaom commented 2 years ago

Just a note that scitype checks have now (MLJBase 0.18.0) been relaxed to allow transformers that need a target.

ablaom commented 4 months ago

Resolved.