What exactly do you mean by retrieving the data? Learners generally don't store the task (the input data); you provide it. But you can always access the stored model, like you've done, and get the original data from the input task. Perhaps I don't quite understand the use case, though.
library(mlr3verse)
#> Loading required package: mlr3
task = tsk("iris")
learner = lrn("classif.rpart")
learner$train(task)
learner$model
#> n= 150
#>
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#>
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
#>   2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
#>     6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
#>     7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
Created on 2024-07-22 with reprex v2.1.1
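A minimal sketch of that suggestion, assuming the same tsk("iris") task as above: the trained learner only holds the fitted rpart object, so the data itself (and anything derived from it, such as variable ranges) comes from the task that was passed to $train().

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
learner$train(task)

# The learner keeps the fitted rpart object, not the training data
class(learner$model)  # "rpart"

# The original data lives in the task, as a data.table (target + features)
dat = task$data()
dim(dat)  # 150 5

# Restrict to a subset of columns, e.g. the two variables the tree splits on
task$data(cols = c("Petal.Length", "Petal.Width"))
```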
Thanks for the reply @m-muecke. Let me try to add some more context.
I have written a package called parttree that produces 2-D plots of simple decision tree partitions. (Here "simple" means containing no more than two explanatory features.) It's not so important, but the original motivation for the package was to help my students visualize how a tree was carving up the data space.
While I originally wrote the package with ggplot2 in mind, in the latest development version I'm adding a base plot.parttree method (which leverages tinyplot under the hood to create a legend, etc.).
Here's a simple example of the package in action:
pkgload::load_all("~/Documents/Projects/parttree")
library(rpart)
fit = rpart(Kyphosis ~ Start + Age, data = kyphosis)
pt = parttree(fit)
plot(pt, pch = 19, palette = "dark")
What's happening under the hood is that parttree(tree) extracts the partition nodes and coerces them into a simple data frame containing the coordinates of the partition rectangles.
pt
#> node Kyphosis path xmin xmax ymin ymax
#> 1 3 present Start < 8.5 -Inf 8.5 -Inf Inf
#> 2 4 absent Start >= 8.5 --> Start >= 14.5 14.5 Inf -Inf Inf
#> 3 10 absent Start >= 8.5 --> Start < 14.5 --> Age < 55 8.5 14.5 -Inf 55
#> 4 22 absent Start >= 8.5 --> Start < 14.5 --> Age >= 55 --> Age >= 111 8.5 14.5 111 Inf
#> 5 23 present Start >= 8.5 --> Start < 14.5 --> Age >= 55 --> Age < 111 8.5 14.5 55 111
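For the rpart frontend, the partition nodes listed above can be read off the fitted object itself. A minimal sketch, not necessarily how parttree does it: terminal nodes are the rows of fit$frame whose var column is "<leaf>" (the row names hold rpart's node numbers), and rpart::path.rpart() returns the sequence of splits leading to each node.

```r
library(rpart)

fit = rpart(Kyphosis ~ Start + Age, data = kyphosis)

# Terminal nodes: rows of fit$frame whose split variable is "<leaf>"
leaves = as.integer(rownames(fit$frame)[fit$frame$var == "<leaf>"])
leaves

# List of split paths, named by node number (print.it = FALSE suppresses printing)
paths = path.rpart(fit, nodes = leaves, print.it = FALSE)
paths[["10"]]
```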
Importantly (and this is the key part for my current issue here), I also need to store some information about the extent of the original data. Why? Because otherwise I won't know the limits of the "outer" rectangles at plot time. So the Inf values in the data frame above get replaced by the relevant values of xrange and yrange below.
attributes(pt)[["parttree"]]
#> $xvar
#> [1] "Start"
#>
#> $yvar
#> [1] "Age"
#>
#> $xrange
#> [1] 1 18
#>
#> $yrange
#> [1] 1 206
#>
#> $response
#> [1] "Kyphosis"
#>
#> $call
#> rpart(formula = Kyphosis ~ Start + Age, data = kyphosis)
#>
#> $na.action
#> NULL
#>
#> $raw_data
#> NULL
So I need the original data in order to calculate the range of each plotted variable.
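To make the role of the data concrete, here is a minimal sketch of the Inf replacement described above, using the rectangle bounds from the pt output and the kyphosis data that ships with rpart. It is an illustration, not parttree's actual code; for an mlr3 learner, the same range() calls would have to run on columns of task$data() instead, which is exactly why the data is needed.

```r
library(rpart)  # provides the kyphosis data

# Rectangle bounds copied from the pt output above
rects = data.frame(
  node = c(3, 4, 10, 22, 23),
  xmin = c(-Inf, 14.5, 8.5, 8.5, 8.5),
  xmax = c(8.5, Inf, 14.5, 14.5, 14.5),
  ymin = c(-Inf, -Inf, -Inf, 111, 55),
  ymax = c(Inf, Inf, 55, Inf, 111)
)

# Observed extent of the two features (matches xrange/yrange above)
xrange = range(kyphosis$Start)
yrange = range(kyphosis$Age)

# Replace only the infinite "outer" bounds with the data limits
rects$xmin[rects$xmin == -Inf] = xrange[1]
rects$xmax[rects$xmax == Inf] = xrange[2]
rects$ymin[rects$ymin == -Inf] = yrange[1]
rects$ymax[rects$ymax == Inf] = yrange[2]
rects
```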
For mlr3, the workflow would look something like this:
library(mlr3)
mytask = tsk("iris")
learner = lrn("classif.rpart")
learner$train(mytask)
pt2 = parttree(learner)
plot(pt2)
Because the user passes the learner object to parttree(), that function needs to be able to retrieve the original data from the learner.
P.S. You might have noticed that plot(pt) also adds the original data as points. That's not essential, but it is another reason why I'd like to be able to retrieve the original data.
Since the learner doesn't store the task, I see two options:
1. Have the user pass the task (or its data) to parttree() along with the learner.
2. Retrieve the data from learner$model; this is what some mlr3viz::autoplot() functions do. See the rpart learner as an example: https://github.com/mlr-org/mlr3viz/blob/main/R/LearnerClassifRpart.R#L52
Since not every model keeps the data in the model object, and since each model would have to be handled differently, I would go for approach 1.
Thanks, that's helpful.
I think we can close this now, but one quick question first: is there a formal way to check whether keep_model = TRUE was passed to a learner?
I'm not sure if there is a formal way, but generally you can retrieve the values from the param set and then check whether the value was passed, i.e., that it is not NULL and is set to TRUE, as follows:
library(mlr3)
task = tsk("iris")
learner = lrn("classif.rpart")
learner$param_set
#> <ParamSet(10)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: cp ParamDbl 0 1 Inf 0.01
#> 2: keep_model ParamLgl NA NA 2 FALSE
#> 3: maxcompete ParamInt 0 Inf Inf 4
#> 4: maxdepth ParamInt 1 30 30 30
#> 5: maxsurrogate ParamInt 0 Inf Inf 5
#> 6: minbucket ParamInt 1 Inf Inf <NoDefault[0]>
#> 7: minsplit ParamInt 1 Inf Inf 20
#> 8: surrogatestyle ParamInt 0 1 2 0
#> 9: usesurrogate ParamInt 0 2 3 2
#> 10: xval ParamInt 0 Inf Inf 10 0
# check if it was passed
pars = learner$param_set$get_values()
isTRUE(pars$keep_model)
#> [1] FALSE
learner = lrn("classif.rpart", keep_model = TRUE)
learner$param_set
#> <ParamSet(10)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: cp ParamDbl 0 1 Inf 0.01
#> 2: keep_model ParamLgl NA NA 2 FALSE TRUE
#> 3: maxcompete ParamInt 0 Inf Inf 4
#> 4: maxdepth ParamInt 1 30 30 30
#> 5: maxsurrogate ParamInt 0 Inf Inf 5
#> 6: minbucket ParamInt 1 Inf Inf <NoDefault[0]>
#> 7: minsplit ParamInt 1 Inf Inf 20
#> 8: surrogatestyle ParamInt 0 1 2 0
#> 9: usesurrogate ParamInt 0 2 3 2
#> 10: xval ParamInt 0 Inf Inf 10 0
pars = learner$param_set$get_values()
isTRUE(pars$keep_model)
#> [1] TRUE
Created on 2024-07-24 with reprex v2.1.1
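One caveat: the param set reflects the learner's current configuration, which could have been changed after training. A complementary sketch checks the trained model directly. It assumes (as the mlr3viz rpart code linked above suggests) that mlr3's classif.rpart forwards keep_model to rpart's model argument, in which case the fitted rpart object carries a copy of its model frame in $model.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart", keep_model = TRUE)
learner$train(task)

# Configured value: the learner's current setting
isTRUE(learner$param_set$values$keep_model)

# Trained model: assumed to hold the model frame only if keep_model = TRUE
# was in effect when $train() ran
has_data = !is.null(learner$model$model)
has_data
if (has_data) dim(learner$model$model)
```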
Perfect. Thanks for all your help @m-muecke. Much appreciated!
Retrieving data from a task is easily done with the data() method. However, I'm looking to find an equivalent method that will work on a learner.
Created on 2024-07-18 with reprex v2.1.0
Any suggestions would be welcome!
Context: I am developing a simple package that supports 2-D plotting of decision tree partitions and need the original data to establish the extent of the plot window. I have been able to find appropriate data retrieval methods for other frontends like partykit and tidymodels, but am struggling with mlr3.