JuliaML / TableTransforms.jl

Transforms and pipelines with tabular data in Julia
https://juliaml.github.io/TableTransforms.jl/stable
MIT License

API design discussion #220

Closed CameronBieganek closed 8 months ago

CameronBieganek commented 8 months ago

Summary

It seems to me that revert and external caches can be completely avoided by using the usual fit and predict approach for machine learning pipelines. And this can be done in a functional, immutable fashion.

Discussion

I skimmed a couple sections in your book. Here's a quote from Chapter 7:

A very common workflow in geospatial data science consists of:

  1. Transforming the data to an appropriate sample space for geostatistical analysis
  2. Doing additional modeling to predict variables in new geospatial locations
  3. Reverting the modeling results with the saved pipeline and cache

I also skimmed through the mineral deposits example in Chapter 12 to see this approach in action. I might have misunderstood something, but it seems like the overall goal is to fit a Kriging model to some data, and then use that Kriging model to make predictions at the various points (or centroids) on a grid.

The approach taken by TableTransforms (and GeoStats) is to apply a pipeline and save an external cache, make a prediction, and then use the cache to revert.

I might be getting the details below a bit wrong, because GeoStats probably has special handling for the spatial coordinates (which are input features to the Kriging model), but overall it seems like the traditional machine learning approach can be applied, something like this:

model = fit(untrained_model, data, target)
interp = predict(model, grid)

...where untrained_model is a pipeline that has some sort of Kriging/Interpolation model at the end of it. This way the user doesn't need to handle the cache object themselves. This can still be done with a functional, immutable approach. For example, Center might look something like this:

struct Center{T}
    means::T
end

Center() = Center(nothing)

function fit(c::Center, data)
    # means = ...
    Center(means)
end

Or, of course, you could use a Union{T, Nothing}:

struct Center
    means::Union{Vector{Float64}, Nothing}
end
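To make the sketch concrete, here is a minimal runnable version of the immutable fit/transform pattern. The `fit` and `transform` names are illustrative, not the actual TableTransforms.jl API, and columns are represented as a plain NamedTuple of vectors:

```julia
using Statistics  # stdlib, for mean

struct Center{T}
    means::T
end

Center() = Center(nothing)

# Fitting returns a *new* Center that stores the per-column means.
function fit(c::Center, data)
    means = map(mean, data)
    Center(means)
end

# Applying a fitted Center subtracts the stored means; no external cache.
transform(c::Center, data) = map((col, m) -> col .- m, data, c.means)

data = (a = [1.0, 2.0, 3.0], b = [10.0, 20.0, 30.0])
fitted = fit(Center(), data)
centered = transform(fitted, data)
```

The unfitted `Center(nothing)` stays trivially cheap to pass around, while the fitted object carries its learned state with it instead of an external cache.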

juliohm commented 8 months ago

Hi @CameronBieganek ! Thanks for sharing these ideas. We carefully designed our transforms so that they only store the hyperparameters, which are static in memory and trivial to pass around in parallel jobs. For example, we have the PCA transform that only stores the output dimension and the column names. The cache is much more expensive to pass around as it contains the principal component basis.

Another issue with the design you suggested is that it glues the transform hyperparameters to the specific data set used in the apply step. This was one of the main reasons we did not advance with MLJ.jl or FeatureTransforms.jl. If you take a look at the mineral deposit example in the book, we fit the Kriging model to the drill hole samples (a set of points along a trajectory) and predict on the Cartesian grid. Saving information about the trajectory doesn't help in the revert step. The caches we need to save are sometimes very complicated.

Besides the two issues above, we are aiming for a more general approach to pipeline optimization, which may include neural networks. If we had to save the weights of a neural network model inside the struct that is passed around as a transform, we would be screwed.

I will close the issue, but feel free to continue the discussion here or on Zulip.

CameronBieganek commented 8 months ago

I'm not really convinced by your argument about parallel processing. Note that in the example I gave above the transforms are still immutable, and transforms/models that have not yet been fit have all their fitted parameter fields set to nothing, so they are still lightweight to pass around.

I now understand why you have revert. In the mineral deposit example, the transforms are all applied to the target values, not the feature values. So revert is analogous to TransformedTargetRegressor in scikit-learn and TransformedTargetModel in MLJ. The TransformedTargetRegressor and TransformedTargetModel types are admittedly a little clunky, but they get the job done, and they have the advantage that the user can stick with the usual fit and predict ML workflow without having to manually handle the cache object. Perhaps there is a cleaner API somewhere between the sklearn approach and the current approach in TableTransforms.jl.
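For illustration, the "transformed target" idea can be sketched in a few lines of Julia. This is only an analogy to what TransformedTargetRegressor and TransformedTargetModel do, not their actual APIs; all names below are hypothetical, and the mean "model" is a stand-in for a real regressor:

```julia
# Sketch of the transformed-target idea: the wrapper applies the target
# transform before fitting and automatically inverts predictions, so the
# user never handles a cache. All names here are hypothetical.
struct TransformedTarget{F,G}
    func::F      # forward transform applied to y before fitting
    invfunc::G   # inverse transform applied to predictions
end

# A trivial stand-in "model" that predicts the mean of its training targets.
struct MeanModel
    μ::Float64
end

fitmean(y) = MeanModel(sum(y) / length(y))
predictmean(m::MeanModel) = m.μ

function fitpredict(t::TransformedTarget, y)
    m = fitmean(t.func.(y))       # fit in transformed target space
    t.invfunc(predictmean(m))     # the "revert" happens inside the wrapper
end

t = TransformedTarget(log, exp)
y = [1.0, 10.0, 100.0]
pred = fitpredict(t, y)           # geometric mean via the log/exp round trip
```

The point is that the inverse transform is applied inside the wrapper, so the caller only ever sees fit and predict.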

Taking a higher level perspective, the main reason I don't like cache is because it makes the API more verbose and complicated (in my opinion), and because it leaks implementation details to the user. There is no reason why the user should have to manually handle and worry about cache objects.

Apologies, but now I'm going to move on to a more general design discussion. Hopefully you can take it in the spirit of trying to improve the ML ecosystem. I think TableTransforms.jl has many nice aspects, but I think there are also some elements of the design that could be improved. I could open separate issues, but this is probably more of an open-ended design discussion. (I don't have Zulip or Slack.)

Taking a look at this line of code from the book:

interp = samples |> InterpolateNeighbors(blocks, models...)

That's a rather odd table transform. It takes a table with 2,000 rows and returns a table with ~100,000 rows. I think a better conceptual model for what is going on is that there are two steps: fit a model pipeline with samples, and then use that pipeline to predict on the new data blocks. Also, introducing more verbs (lowercase function calls) into the workflow makes it more clear where the action is happening and what kind of action is happening.

It also might be helpful to make a distinction between table queries and ML/statistics/data science transformations (which normally preserve the number of rows). Right now it is easy to make an invalid ML pipeline:

model = Filter(row -> row.a > 10) → RandomForestRegressor()

It's ok to filter the training data, but it is not ok to filter the prediction data. When you put a model into production, you usually need to make a prediction on every input observation. Even before production, in the context of a train/test split, it could so happen that you filter out all the data in the test set. At a minimum, your specified train-test split fraction will not hold if you apply the filtering as part of the model (after the train-test split). (I know I'm extrapolating your design by attaching an ML model at the end of the pipe, but that is the typical ML approach.)

This is a bit of bikeshedding, but apply and reapply are very generic verbs. I prefer fit, transform, and fit_transform because they are more specific and because they are already well established terms in the ML community. (At least within the scikit-learn community, but that accounts for a pretty large percentage of ML practice.)


I originally typed the following up, but then I realized that you can do the exact same thing in scikit-learn. It's just that the documentation makes it clear that the normal pattern is to attach an estimator to the end of a pipeline, and then provide the fit and predict methods with the raw (untransformed) X and y data.

Making transforms callable is cute, but it makes it easy for users to make mistakes. One can easily use an untrained pipeline for both training and prediction, like this:

transform = ZScore() → EigenAnalysis(:V)
modelfit = fit(model, transform(Xtrain), ytrain)
ŷ = predict(modelfit, transform(Xtest))
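The safer pattern fits the transform once on the training data and then reuses the fitted state everywhere. A minimal sketch with hypothetical names (not the actual TableTransforms.jl API):

```julia
using Statistics  # stdlib, for mean and std

# Hypothetical sketch: fit the transform once on the training features and
# reuse the *fitted* state for both train and test, so an untrained
# transform can never silently refit on the test set.
struct ZScoreFit
    μ::Float64
    σ::Float64
end

fitzscore(x) = ZScoreFit(mean(x), std(x))
applyzscore(z::ZScoreFit, x) = (x .- z.μ) ./ z.σ

Xtrain = [1.0, 2.0, 3.0, 4.0]
Xtest  = [2.0, 6.0]

z = fitzscore(Xtrain)            # statistics come from the training set only
Ztrain = applyzscore(z, Xtrain)
Ztest  = applyzscore(z, Xtest)   # test data reuses the training statistics
```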

juliohm commented 8 months ago

Thank you for the feedback. Below are specific comments.

Taking a higher level perspective, the main reason I don't like cache is because it makes the API more verbose and complicated (in my opinion), and because it leaks implementation details to the user. There is no reason why the user should have to manually handle and worry about cache objects.

Maybe you are misunderstanding the goal of the cache in the revert step. It is not something that sk-learn or MLJ supports, as far as I know. It is about "undoing" pipelines; it has nothing to do with fit/predict. Our fit/predict is quite clean here: it consists of creating the transform object on "training" data and calling it as a functor on "test" data. We are still brainstorming this API in the Learn transform in StatsLearnModels.jl, which adheres to TableTransforms.jl.

Taking a look at this line of code from the book:

interp = samples |> InterpolateNeighbors(blocks, models...)

That's a rather odd table transform. It takes a table with 2,000 rows and returns a table with ~100,000 rows. I think a better conceptual model for what is going on is that there are two steps: fit a model pipeline with samples, and then use that pipeline to predict on the new data blocks. Also, introducing more verbs (lowercase function calls) into the workflow makes it more clear where the action is happening and what kind of action is happening.

Try to think outside the fit/predict box of other frameworks. The InterpolateNeighbors transform is a geospatial transform; it has nothing to do with statistical learning models where you have a "train" table and a "test" table of features. It only takes a geospatial domain (blocks) and performs interpolation.

Your entire discussion should probably be narrowed down to our Learn transform, which is more related to what MLJ and sk-learn can do.

It also might be helpful to make a distinction between table queries and ML/statistics/data science transformations (which normally preserve the number of rows). Right now it is easy to make an invalid ML pipeline:

model = Filter(row -> row.a > 10) → RandomForestRegressor()

It's ok to filter the training data, but it is not ok to filter the prediction data.

Thanks, but I disagree. The idea of the TransformsBase.jl API is to be able to combine all sorts of transforms agnostically. We don't need categorization to create sophisticated pipelines involving geometric, statistical, cleaning, and other transforms.

Also, your argument here and in other parts of the text below is not very good. You are saying something like: "users can do messy things, so we should limit the power of our interface to only handle a subset of features"

This is a bit of bikeshedding, but apply and reapply are very generic verbs. I prefer fit, transform, and fit_transform because they are more specific and because they are already well established terms in the ML community. (At least within the scikit-learn community, but that accounts for a pretty large percentage of ML practice.)

As far as I understand it, apply and reapply have a different purpose than fit and fit_transform in other frameworks. Also, the latter are jargon, and our transforms go beyond ML transforms.

CameronBieganek commented 8 months ago

Maybe you are misunderstanding the goal of the cache in the revert step. It is not something that sk-learn or MLJ supports, as far as I know. It is about "undoing" pipelines; it has nothing to do with fit/predict.

Yes, they do support it, as I mentioned above. The revert functionality is handled by TransformedTargetRegressor in sklearn and TransformedTargetModel in MLJ. I admit that those types do not provide the prettiest interface ever, but they do have some advantages (which I also mentioned above).

Aside from undoing target transformations, which sklearn and MLJ provide, I'm not sure what other use case there is for "undoing" a pipeline.

Try to think outside the fit/predict box of other frameworks.

I'm perfectly capable of thinking outside the fit/predict box. I just happen to think that fit/predict is a better abstraction for what is going on here. Eventually you end up contorting yourself to adhere to the "Everything is a table transform" philosophy.

The InterpolateNeighbors transform is a geospatial transform; it has nothing to do with statistical learning models where you have a "train" table and a "test" table of features. It only takes a geospatial domain (blocks) and performs interpolation.

Perhaps geospatial scientists are less interested in testing the generalization error of their models than data scientists are. You can and probably should split your mineral deposit samples into train and test sets so that you can empirically estimate how accurate your Kriging model is.

Scikit-learn does have a Kriging model, which uses fit and predict methods. I find their interface more intuitive.

Also, your argument here and in other parts of the text below is not very good. You are saying something like: "users can do messy things, so we should limit the power of our interface to only handle a subset of features"

A good API helps the user avoid errors. The power of a table query language like SQL comes from its limited scope. Here's a quote from "Database Systems: The Complete Book" (page 38) in answer to the question "Why do we need a special query language?":

The surprising answer is that relational algebra is useful because it is less powerful than C or Java. That is, there are computations one can perform in any conventional language that one cannot perform in relational algebra. An example is: determine whether the number of tuples in a relation is even or odd. By limiting what we can say or do in our query language, we get two huge rewards — ease of programming and the ability of the compiler to produce highly optimized code.

juliohm commented 8 months ago

Perhaps geospatial scientists are less interested in testing the generalization error of their models than data scientists are. You can and probably should split your mineral deposit samples into train and test sets so that you can empirically estimate how accurate your Kriging model is.

Check our paper, which is all about generalization error in geospatial settings: https://www.frontiersin.org/articles/10.3389/fams.2021.689393/full

We are in the process of porting all these validation methods to be compatible with TableTransforms.jl pipelines. They are already implemented in GeoStats.jl as you can see here: https://juliaearth.github.io/GeoStatsDocs/stable/validation.html

Also watch our JuliaCon talk for more examples: https://www.youtube.com/watch?v=75A6zyn5pIE

The things that we can do are already much more sophisticated than what sk-learn or MLJ can do, because they are not flexible enough to handle geospatial domains, efficient (lazy) partitioning schemes, etc.