Thanks for the issue, @grantmcdermott!
Looks like rpart carries around the data it was fitted on, inside the environment attached to its terms object (the `.Environment` attribute):
```r
library(parsnip)
library(rpart)

tree_parsnip <-
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Species ~ Petal.Length + Petal.Width, data = iris)

dim(attr(tree_parsnip$fit$terms, ".Environment")$data)
#> [1] 150   5
head(attr(tree_parsnip$fit$terms, ".Environment")$data)
#>   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> 1          5.1         3.5          1.4         0.2  setosa
#> 2          4.9         3.0          1.4         0.2  setosa
#> 3          4.7         3.2          1.3         0.2  setosa
#> 4          4.6         3.1          1.5         0.2  setosa
#> 5          5.0         3.6          1.4         0.2  setosa
#> 6          5.4         3.9          1.7         0.4  setosa
```
Created on 2024-07-15 with reprex v2.1.0
If you'd rather use the on-label `rpart::rpart(model)` argument, which lets you explicitly opt to keep the model frame around, you can pass it as an "engine argument" to `set_engine()` in the usual way:
```r
library(parsnip)
library(rpart)

tree_parsnip <-
  decision_tree() |>
  set_engine("rpart", model = TRUE) |>
  set_mode("classification") |>
  fit(Species ~ Petal.Length + Petal.Width, data = iris)

dim(tree_parsnip$fit$model)
#> [1] 150   3
head(tree_parsnip$fit$model)
#>   Species Petal.Length Petal.Width
#> 1  setosa          1.4         0.2
#> 2  setosa          1.4         0.2
#> 3  setosa          1.3         0.2
#> 4  setosa          1.5         0.2
#> 5  setosa          1.4         0.2
#> 6  setosa          1.7         0.4
```
Note that the `update()` methods are defined for model specifications rather than fitted models (e.g. `decision_tree()` rather than `fit(decision_tree())`), so no training data is passed around there.
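Combining the two approaches above, here is a minimal sketch of a helper that tries to recover the training data automatically. The name `get_training_data()` is hypothetical; it is not part of parsnip or rpart. It assumes `fit` is an rpart object: it first returns the model frame if one was kept via `model = TRUE`, and otherwise evaluates the call's stored `data` argument in the terms object's environment instead of the global environment, which is where `eval(tree$call$data)` goes wrong. The usage below only needs base R plus rpart; for a parsnip fit you would pass `tree_parsnip$fit`, which should pick up the data frame shown in the first reprex, but that relies on parsnip internals that may change between versions.

```r
library(rpart)

# Hypothetical helper (not part of parsnip or rpart): try to recover the
# data an rpart fit was trained on. Returns NULL if nothing usable is found.
get_training_data <- function(fit) {
  # 1. If the model frame was kept via `model = TRUE`, rpart stores it in
  #    `fit$model` (only the variables that appear in the formula).
  if (is.data.frame(fit$model)) {
    return(fit$model)
  }
  # 2. Otherwise, evaluate the call's `data` argument in the environment
  #    attached to the terms object rather than the global environment,
  #    where a bare `data` symbol would resolve to `utils::data()`.
  data_expr <- fit$call$data
  if (is.null(data_expr)) {
    return(NULL)
  }
  data <- eval(data_expr, envir = environment(fit$terms))
  if (is.data.frame(data)) data else NULL
}

# Usage with plain rpart fits on the built-in iris data
fit_model <- rpart(Species ~ Petal.Length + Petal.Width, data = iris, model = TRUE)
fit_plain <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)

dim(get_training_data(fit_model))  # model frame: 150 rows, 3 columns
dim(get_training_data(fit_plain))  # full iris data: 150 rows, 5 columns
```

The two branches return different things: the model frame holds only the variables used in the formula, while the second branch returns the whole data frame that was originally passed in.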
Ah, super. Thanks @simonpcouch! Much appreciated, bud.
The problem
Hi folks. I'm having trouble retrieving the original data object that was passed to the "rpart" engine. The old `eval(tree$call$data)` approach doesn't work here because the data object is hidden (the call only refers to it as `data`)... so evaluating it falls back to the global `utils::data()` function. I tried to see how you implement this as part of the `update` method for rpart backends. But unless I'm missing something, an `_rpart.update` method is not supported either.
Reproducible example
Created on 2024-07-15 with reprex v2.1.0
Any suggestions for a workaround would be much appreciated. Thanks in advance.
P.S. Possibly related to https://github.com/tidymodels/parsnip/issues/257, but I need an automated way to retrieve the original data object. In that issue, the solution was to manually write it back to the call. Similarly, `repair_call` also needs the user to specify the data object manually.