JuliaTrustworthyAI / LaplaceRedux.jl

Effortless Bayesian Deep Learning through Laplace Approximation for Flux.jl neural networks.
https://juliatrustworthyai.github.io/LaplaceRedux.jl/
MIT License

Initial review of MLJ interface. #120

Open ablaom opened 1 week ago

ablaom commented 1 week ago

I'm posting this in response to the request at https://github.com/JuliaAI/MLJModels.jl/issues/571.

I can see some work has gone into understanding MLJ's API requirements (and into understanding the internals of MLJFlux).

I have not made an exhaustive review of the interface, but I list below some issues identified so far. Read point 4 first, as it is the most serious.

1. Form of predictions

Whenever possible, probabilistic predictions must take the form of a vector of distributions, where a "distribution" is something implementing Distributions.pdf and Random.rand (docs). So, instead of returning raw probabilities, the classifier should return a vector with element type UnivariateFinite (owned by CategoricalDistributions.jl). For example, here's what MLJFlux.NeuralNetworkClassifier predictions look like:

julia> predict(mach, rows=1:3)
3-element UnivariateFiniteVector{Multiclass{3}, String, UInt32, Float32}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.342, versicolor=>0.349, virginica=>0.308)
 UnivariateFinite{Multiclass{3}}(setosa=>0.334, versicolor=>0.345, virginica=>0.322)
 UnivariateFinite{Multiclass{3}}(setosa=>0.337, versicolor=>0.345, virginica=>0.318)

Perhaps you can mirror the code for that model, here.

Similarly, the regressor should return a vector of whatever Distributions.jl distribution you are using, e.g. a vector of Distributions.Normal, and not simply return parameters.
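
For illustration, predictions of the required form can be constructed like this (a minimal sketch, not LaplaceRedux code; labels and parameters are made up):

using CategoricalDistributions, Distributions

# Classifier: wrap a matrix of raw probabilities (one row per sample, rows
# summing to one) into a vector of UnivariateFinite distributions.
# `pool=missing` builds a fresh categorical pool from the given labels.
probs = [0.3 0.4 0.3; 0.2 0.5 0.3]
yhat = UnivariateFinite(["setosa", "versicolor", "virginica"], probs; pool=missing)

# Regressor: wrap predicted means and variances into Normal distributions.
μ, σ² = [1.0, 2.0], [0.1, 0.2]
yhat_reg = [Normal(m, sqrt(v)) for (m, v) in zip(μ, σ²)]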

2. Table handling

I suspect there is something non-generic about the table handling. If I train the classifier using data X, y = MLJBase.@load_iris I get an error, although training with X, y = make_moons() works fine. Getting the number of rows of a generic table (if that's the issue) has always been a bit of a problem, because the Tables.jl API was designed to include tables without length. I think the idea is that you should use DataAPI.nrow (row singular) for this, but I think MLJModelInterface.nrows or MLJBase.nrows (rows plural) are probably okay.
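
For example (a sketch, assuming the data is ultimately converted to a matrix anyway; X stands for whatever table the user passes):

import MLJModelInterface as MMI

Xmatrix = MMI.matrix(X)   # works for any Tables.jl-compatible table
n = MMI.nrows(X)          # generic row count, no `length` assumption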

3. Metadata/traits

The load_paths are wrong (see correction below).

Your input and target types need some tweaking. For example, I'm getting warnings with the above data sets about the type of data when I do machine(model, X, y). One problem is that you have Finite in some places where you probably want <:Finite, because Finite is a UnionAll type (parameterised, Finite{N}). See my suggestion below.

a. Do you really support inputs X with categorical features? (If you do, you may be interested in the pending MLJFlux PR which adds entity embedding for categorical features for the non-image models. This might be more useful than static one-hot encoding, if that is what you do to handle categoricals.)

b. Do you really support classification for non-categorical targets y? (You currently allow y to be Continuous.)

c. Do you really intend to support regression with categorical targets y? What would that mean?

d. Do you really intend to exclude mixed data types in the input X (some categorical, some continuous)?

e. Do you handle OrderedFactor and Multiclass differently (as you probably should)? If not, perhaps you mean to restrict to Multiclass and have the user coerce OrderedFactor to Continuous (assuming you do not already do this under the hood).

Assuming the answers to a–d are: yes, no, no, no, here's my stab at a revised metadata declaration:

MLJBase.metadata_model(
    LaplaceClassification;
    input_scitype=Union{
        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
    },
    target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
    load_path="LaplaceRedux.LaplaceClassification",
)
# metadata for each model:
MLJBase.metadata_model(
    LaplaceRegression;
    input_scitype=Union{
        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
    },
    target_scitype=AbstractArray{MLJBase.Continuous},
    load_path="LaplaceRedux.MLJFlux.LaplaceRegression",
)
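
As a quick check of declared scitypes against actual data, one can inspect the data's scientific types directly (a small sketch; schema and scitype are re-exported by MLJBase):

using MLJBase

X, y = MLJBase.@load_iris
schema(X)    # per-column scitypes of the input table
scitype(y)   # e.g. AbstractVector{Multiclass{3}}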

4. Use of private API (more serious)

The overloaded methods MLJFlux.shape, MLJFlux.build(::FluxModel, ...), MLJFlux.fitresult, and MLJFlux.train are not public API. They are simply abstractions that arose to try to remove some code duplication between the different models provided by MLJFlux. I am consequently reluctant to make them public. Indeed, the entity embedding PR referred to above breaks MLJFlux.fitresult, and future patch releases may break the API further. There may be a good argument for making this API public, but I feel this requires a substantial rethink. Indeed, your own attempts to "hack" aspects of this API reveal the inadequacies: the fact that you feel the need to overload MLJFlux.train at all; the fact that the chain gets modified in train and not in some earlier stage; etc.

Unfortunately, I personally don't have the bandwidth for this kind of refactoring of MLJFlux any time soon. Your best option may simply be to cut and paste the MLJFlux code you need and have LaplaceRedux own independent versions of the private MLJFlux API methods referenced above. Alternatively, you could leave things as they are and live with breakages as they occur. I'm not sure how keen I am on registering such a model, however. Perhaps we wait and see how stable the internal API winds up being.

5. Minor nomenclature point

For consistency with other MLJ models, I suggest LaplaceRegressor over LaplaceRegression and LaplaceClassifier over LaplaceClassification. Of course I understand you may have other reasons for the name choices.

pat-alt commented 1 week ago

Thank you so much for taking the time to write this up @ablaom, that's super helpful and much appreciated!

@pasq-cat has been working on this as part of Google Summer of Code, which he's finishing up in the coming days. Once that's done, we'll come back to this.

pat-alt commented 6 days ago

@ablaom I've just come back to this, because I'm facing very similar issues with another of our packages: https://github.com/JuliaTrustworthyAI/JointEnergyModels.jl?tab=readme-ov-file

The reason we targeted MLJFlux in both cases was that the underlying atomic models are Flux models and MLJFlux comes with useful functionality like the builders.
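
(For context, the builders referred to specify a network architecture once the data dimensions are known; a typical usage sketch, assuming current MLJFlux names:)

using MLJFlux, Flux

builder = MLJFlux.MLP(hidden=(32, 32), σ=Flux.relu)   # MLP with two hidden layers
clf = MLJFlux.NeuralNetworkClassifier(builder=builder, epochs=10)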

The fact that you feel the need to overload MLJFlux.train at all; the fact that the chain gets modified in train and not in some earlier stage; etc.

This was done here in order to make Laplace approximation (LA) part of the call to MLJ.fit. In light of your comments (especially point 4), I see two other possible solutions here @pasq-cat:

1. Interface MLJ directly (not through MLJFlux)

This is manageable (see e.g. NeuroTreeModels). We lose the added functionality of MLJFlux, like builders and such, but I'm starting to think we'll have to live with that (a rough sketch of the idea appears at the end of this comment).

2. Give up on MLJ interface for now

Another option would be to give up on the custom MLJ models and instead just run LA post hoc. In this scenario, users could rely on MLJFlux as they normally do to train conventional neural networks and then run LA. But this does not seem ideal, since we need the interface for https://github.com/JuliaTrustworthyAI/ConformalPrediction.jl/pull/125

I'm leaning towards interfacing MLJ directly, but curious to hear your thoughts.
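
For concreteness, a direct interface might look roughly like this (a hypothetical sketch following the MLJModelInterface docs; the struct name, fields, and elided logic are all illustrative):

import MLJModelInterface as MMI
using Flux

# A model struct with a field for the user-supplied Flux chain:
mutable struct LaplaceClassifier <: MMI.Probabilistic
    flux_chain::Flux.Chain
    epochs::Int
end

function MMI.fit(model::LaplaceClassifier, verbosity, X, y)
    Xmatrix = permutedims(MMI.matrix(X))  # features × observations, as Flux expects
    # ... train model.flux_chain, then fit the Laplace approximation post hoc ...
    fitresult = nothing  # would hold the fitted Laplace object plus the class pool
    cache, report = nothing, NamedTuple()
    return fitresult, cache, report
end

function MMI.predict(model::LaplaceClassifier, fitresult, Xnew)
    # ... call LaplaceRedux.predict on the stored Laplace object and wrap the
    # probabilities as UnivariateFinite (see point 1 of the review above) ...
end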

Thanks again for your input here @ablaom and sorry for submitting this to MLJ a little prematurely.

pasq-cat commented 6 days ago

@pat-alt I guess we can interface MLJ directly and, in the model definition, leave a field for the Flux chain provided by the user. It would have saved us a lot of time.

pasq-cat commented 6 days ago

Thanks for the review and feedback, it's really helpful to have some direct input to improve this.

1. I had added support for distributions, but it was later decided that they should be optional, so we never changed this back. @pat-alt I still think predict should only return distribution objects... managing all the options in predictions would be easier and we could remove the ret_distr parameter.

2. I didn't realize there was this problem.

3. Eh, I missed this one. The metadata were left over from the previous author, from when there was a single model for both classification and regression; that's why both cases were there. I initially split them in two but, between one branch and the others, never went back to change them. I should have updated them, but I guess we failed to notice.

4. Eh, yes, I had zero experience with MLJFlux, so I tried to make it work but didn't understand it completely.

ablaom commented 6 days ago

My sense is that the degree of code complexity you need to add to make MLJFlux.jl work is not worth the extra functionality you gain, and so interfacing MLJ directly may be better. Of course, you should feel free to mirror whatever is useful to you from MLJFlux.jl. Let me know what you decide.

That said, I would support a redesign of MLJFlux's "internal API" that accommodates a wider range of models. I just don't have the resources to do this on my own. It sounds like you and your team would be in a good place to suggest such a design, if you likewise have the resources at some point.

ablaom commented 6 days ago

To clarify an earlier point regarding the form of predictions: it is not absolutely necessary that predictions be distributions. I noticed that prediction is VERY slow at present, so I am guessing you are doing some sort of sampling to get the parametric forms (sorry, I haven't researched LA). An option in MLJ is to directly return a vector of "sampleable" objects (objects implementing just rand). One downside is that MLJ doesn't currently provide any metrics with this weaker assumption on the form of the proxy for the target prediction. If you go this route, you'll want to include plenty of documentation, as there are no other registered models like this at present.
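
A minimal sketch of such a "sampleable" object (illustrative only; PosteriorSampler is not an existing MLJ or LaplaceRedux type):

using Random

# An object implementing only `rand`: enough for this weaker prediction
# contract, without the full Distributions.jl API.
struct PosteriorSampler{F<:Function}
    draw::F  # closure returning one draw from the posterior predictive
end

Base.rand(rng::Random.AbstractRNG, s::PosteriorSampler) = s.draw(rng)
Base.rand(s::PosteriorSampler) = rand(Random.default_rng(), s)

# predict would then return a vector like [PosteriorSampler(rng -> randn(rng)), ...]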

I have also been making the usual assumption that prediction components are not correlated. That is, the nth value of the probabilistic prediction predict(model, fitresult, X) depends only on the nth observation of X. If that is not the case, then see here.

pasq-cat commented 5 days ago

@pat-alt please correct me if I am wrong, but no, there is no sampling involved and the predictions are not correlated; the Hessian is found during the training phase. So I guess it's a defect of our MLJFlux implementation.

ablaom commented 4 days ago

FYI

This timing is fairly reproducible on my machine:

using LaplaceRedux, MLJBase  # imports assumed by this snippet

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification();
fitresult, _, _ = MLJBase.fit(model, 0, X, y);
MLJBase.predict(model, fitresult, X);
@time MLJBase.predict(model, fitresult, X)
# 30.179502 seconds (58.51 k allocations: 8.287 GiB, 2.13% gc time)

For MLJFlux.NeuralNetworkClassifier:

julia> @time MLJBase.predict(model, fitresult, X)
  0.000275 seconds (1.33 k allocations: 83.375 KiB)

pasq-cat commented 4 days ago

@ablaom I went back to the theory behind the Laplace approximation. For each new element x, it is necessary to recompute the Jacobians, so maybe this is why inference takes longer than for a standard neural network.

pat-alt commented 3 days ago

Thanks very much for flagging this @ablaom.

The Jacobian computation is indeed a bottleneck in forward passes, but it's not what is causing this issue. Adapting @ablaom's example from above:

using LaplaceRedux
using MLJBase

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification();
fitresult, _, _ = MLJBase.fit(model, 0, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;

# Ten test samples:
Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');
MLJBase.predict(model, fitresult, Xtest_tab);       # warm up
LaplaceRedux.predict(la, Xmat);                     # warm up

Generating predictions using our MLJ interface vs. our default predict method leads to wildly different computation times:

julia> @time MLJBase.predict(model, fitresult, Xtest_tab);
  1.806758 seconds (5.86 k allocations: 848.624 MiB, 0.78% gc time)

julia> @time LaplaceRedux.predict(la, Xtest);
  0.189871 seconds (4.71 k allocations: 86.994 MiB)

julia> @time glm_predictive_distribution(la, Xtest);
  0.189886 seconds (4.71 k allocations: 86.994 MiB)

Curiously, it takes pretty much exactly 10x as long through the MLJ interface, and here I've chosen 10 test samples. Trying it with 50 test samples confirms that predict through our MLJ interface scales proportionately with the number of inputs:

julia> Xtest = Xmat[:,1:50];

julia> Xtest_tab = MLJBase.table(Xtest');

julia> MLJBase.predict(model, fitresult, Xtest_tab);       # warm up

julia> LaplaceRedux.predict(la, Xmat);                     # warm up

julia> @time MLJBase.predict(model, fitresult, Xtest_tab);
  8.926642 seconds (29.26 k allocations: 4.144 GiB, 0.76% gc time)

julia> @time LaplaceRedux.predict(la, Xtest);
  0.260523 seconds (22.60 k allocations: 101.032 MiB)

julia> @time glm_predictive_distribution(la, Xtest);
  0.252946 seconds (22.60 k allocations: 101.032 MiB)

The problem is the map over samples here, which I've now addressed in #123. That's my bad; I should have spotted this when reviewing your PR @pasq-cat.
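
Schematically (not the exact interface code), the difference is between a per-sample loop and one batched call:

# Slow: mapping predict over individual observations (what the interface did).
predictions = [LaplaceRedux.predict(la, Xmat[:, i]) for i in 1:size(Xmat, 2)]

# Fast: a single batched call over all observations.
predictions = LaplaceRedux.predict(la, Xmat)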

pasq-cat commented 3 days ago

@pat-alt Eh, I was modifying exactly this part yesterday. I was trying:

X_vec = [Vector(row) for row in eachrow(Xnew)]

predictions = [
    LaplaceRedux.predict(
        la,
        x;
        link_approx=model.link_approx,
        predict_proba=model.predict_proba,
        ret_distr=model.ret_distr,
    ) for x in X_vec
]

because I remembered that LaplaceRedux.predict accepts vectors as input.