auroreRECE closed this issue 3 years ago.
I'm not sure, but I think the mean is calculated each time py2 is printed to the console, so it is not stored in the object. At least, I couldn't find it there either.
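A minimal sketch of that per-node calculation without ggparty, assuming the layout partykit uses for a tree converted with as.party(): py2$fitted is a data frame whose "(fitted)" column holds each training row's terminal node id and whose "(response)" column holds the observed Petal.Length. The variable names node_id, response and node_means are illustrative.

```r
library(rpart)
library(partykit)

data("iris")
result_tree <- rpart(Petal.Length ~ Sepal.Width + Sepal.Length, data = iris)
py2 <- as.party(result_tree)

# Stored per-row information: terminal node id and observed response
fitted_df <- py2$fitted
node_id  <- fitted_df[["(fitted)"]]
response <- fitted_df[["(response)"]]

# Mean response in each terminal node, recomputed on demand
node_means <- tapply(response, node_id, mean)
print(round(node_means, 3))
```

The names of node_means are the terminal node ids, so the values line up with the "Node N" labels in the plot.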
So you have to calculate it for each node yourself. You can do this with ggparty's add_vars argument, as follows, and then use the new variable in geom_node_label. Hope that helps :)
library(rpart)
library(ggparty)
#> Loading required package: ggplot2
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
data("iris")
# Regression tree for Petal.Length, converted to a partykit "party" object
result_tree <- rpart(Petal.Length ~ Sepal.Width + Sepal.Length, data = iris)
py2 <- as.party(result_tree)
ggparty(py2,
        # add_vars adds one column per entry to the plot data; here the mean
        # Petal.Length of the observations that fall into each node
        add_vars = list(mean_petal_length = function(data, node)
          round(mean(node$data$Petal.Length), 2))) +
  geom_edge(size = 0.6, color = 'grey60', linetype = 2) +
  geom_edge_label(aes(label = prettyNum(breaks_label, digits = 1)), size = 4) +
  # Split variable names on the inner nodes
  geom_node_label(aes(label = splitvar), ids = "inner") +
  # Boxplot of Petal.Length inside each terminal node
  geom_node_plot(gglist = list(geom_boxplot(aes(x = '', y = Petal.Length,
                                                fill = factor(id))),
                               theme_classic(), ylab(""),
                               theme(axis.text.x = element_blank(),
                                     axis.title.x = element_blank(),
                                     axis.ticks.x = element_blank()),
                               theme(legend.position = "none")),
                 shared_legend = FALSE) +
  # The new mean_petal_length column, shown on the terminal nodes
  geom_node_label(aes(label = paste0("Mean = ", mean_petal_length)),
                  fontface = "bold",
                  ids = "terminal",
                  size = 3,
                  nudge_y = 0.01)
Created on 2021-01-14 by the reprex package (v0.3.0)
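In the reprex, round(..., 2) drops trailing zeros, so a node mean of 1.53 is labelled "1.53" rather than "1.530" as in the printed tree. If the three-decimal style is wanted, one option (a sketch, not part of the original answer) is to remove the round() from the add_vars function and format the label with sprintf(), assuming mean_petal_length ends up as a plain numeric column. The numbers below are illustrative, not the tree's actual node means.

```r
# Illustrative values only, not the actual node means of the tree
node_means <- c(1.53, 4.2567)

print(round(node_means, 2))   # trailing zeros are lost: 1.53 4.26
labels <- sprintf("Mean = %.3f", node_means)
print(labels)                 # "Mean = 1.530" "Mean = 4.257"
```

In the plot, this would correspond to aes(label = sprintf("Mean = %.3f", mean_petal_length)) in the terminal geom_node_label().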
Hello
I am trying to add the mean value of each class to its terminal node. Here is the example:
This gives the following plot:
and this is the printed output of the tree, with the mean value of each terminal node highlighted:
I would like geom_node_label to show the mean value of each class (for example, '1.530' where 'Node 3' is currently written). Is this possible? I can't find where in ggparty(py)$data the mean value of each terminal node is stored.
Thank you for your help and for this amazing package,
Aurore