mlr-org / mlr3pipelines

Dataflow Programming for Machine Learning in R
https://mlr3pipelines.mlr-org.com/
GNU Lesser General Public License v3.0

New predictions with graph/ensemble #556

Closed: MaximilianPi closed this issue 3 years ago

MaximilianPi commented 3 years ago

Hi mlr3 team,

once again, it's a great package! I have a question about predictions with a graph object:

I combine several models via the union operator into one graph to create an ensemble model (is this the correct way, or are there alternatives for creating an ensemble model?). After training, I want to predict on new data that is unknown at the time of the fitting/training step (otherwise I could use a PipeOpLearnerCV). However, this does not seem to be supported, since there is no 'newdata' argument in the predict method. Is that right?

Here's a minimal example:

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

# Wrap each Learner in a PipeOpLearner with its own id
rf = mlr3pipelines::PipeOpLearner$new(mlr3::lrn("classif.ranger"), id = "rf")
knn = mlr3pipelines::PipeOpLearner$new(mlr3::lrn("classif.kknn"), id = "kknn")

# Run both learners side by side and average their predictions
avg = mlr3pipelines::po("classifavg", innum = 2)
ensemble = mlr3pipelines::gunion(list(rf, knn)) %>>% avg

task = mlr3::TaskClassif$new(id = "iris", backend = iris, target = "Species")

ensemble$train(task)

# What I would like to do, but Graph$predict() has no newdata argument:
# ensemble$predict(newdata = iris)

sumny commented 3 years ago

Hi @MaximilianPi,

thanks for the question! In general, you may want to wrap your Graph in a GraphLearner to gain all the functionality a "normal" Learner provides (see also https://mlr3book.mlr-org.com/pipe-modeling.html). Also, you could make the construction of your ensemble slightly easier by doing:

rf = lrn("classif.ranger", id = "rf")
knn = lrn("classif.kknn", id = "kknn")
ensemble = list(rf, knn) %>>% po("classifavg", innum = 2)

(i.e., in most cases mlr3pipelines automatically "knows" when to apply gunion() and when to coerce Learners to PipeOpLearners; technically, though, your code is perfectly fine.)
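One way to see the automatic coercion described above is to build the shorter version of the ensemble and inspect the resulting Graph. This is a minimal sketch that assumes the ranger and kknn packages are installed; it only constructs the graph (no training) and looks at its PipeOp ids, the class of the learner nodes, and the edges table.

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

rf = lrn("classif.ranger", id = "rf")
knn = lrn("classif.kknn", id = "kknn")

# %>>% applied to a list of Learners: each Learner is wrapped in a
# PipeOpLearner (keeping its id) and the branches are joined via gunion()
ensemble = list(rf, knn) %>>% po("classifavg", innum = 2)

# PipeOp ids: the two learners plus the averaging step
print(sort(ensemble$ids()))

# The learner nodes are PipeOpLearner objects
print(class(ensemble$pipeops$rf)[1])
print(class(ensemble$pipeops$kknn)[1])

# Two edges, one from each learner into the averaging PipeOp
print(ensemble$edges)
```

The result is the same structure as the explicit gunion() version from the question, so either style can be used.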

In the future we may want to add a ppl_ensemble function (see ?ppl) that allows for even easier ensemble construction, but currently the way you do it is the way to go.

Now for the GraphLearner part:

ensemble_gl = GraphLearner$new(ensemble)

This GraphLearner now has all the functionality a normal Learner has, including predict_newdata() (see ?Learner):

ensemble_gl$train(task)
ensemble_gl$predict_newdata(iris)
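To mirror the original use case, where the new observations are not available at training time, here is a minimal sketch that holds out a random subset of iris, trains the GraphLearner on the remaining rows only, and then calls predict_newdata() on the held-out data.frame. The split (new_idx, train_data, new_data) is hypothetical and only for illustration; the held-out rows keep their Species column so the prediction can be scored with classification accuracy. It assumes the ranger and kknn packages are installed.

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

set.seed(1)
# Hypothetical split: pretend these 30 rows only arrive after training
new_idx = sample(nrow(iris), 30)
train_data = iris[-new_idx, ]
new_data = iris[new_idx, ]

task = TaskClassif$new(id = "iris_train", backend = train_data, target = "Species")

# Same ensemble as above, wrapped in a GraphLearner
ensemble = list(
  lrn("classif.ranger", id = "rf"),
  lrn("classif.kknn", id = "kknn")
) %>>% po("classifavg", innum = 2)
ensemble_gl = GraphLearner$new(ensemble)

# Train only on the training rows
ensemble_gl$train(task)

# Predict on a plain data.frame that was never part of the Task
pred = ensemble_gl$predict_newdata(new_data)
acc = pred$score(msr("classif.acc"))
print(acc)
```

Because the GraphLearner behaves like any other Learner, the same object can also be passed to resample() or benchmark().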

Please let me know if you found this helpful!

MaximilianPi commented 3 years ago

Ah perfect, many thanks for the quick response.