martin-borkovec / ggparty

Add mean value of a class for a regression tree #44

Closed by auroreRECE 3 years ago

auroreRECE commented 3 years ago

Hello

I am trying to add the mean value of a class in the terminal node. Here is the example:

library(rpart)
library(ggparty)
data("iris")

result_tree <- rpart(Petal.Length ~  Sepal.Width + Sepal.Length, data = iris)

py2 <- as.party(result_tree)

arb = ggparty(py2) +
  geom_edge(size = 0.6, color = 'grey60', linetype = 2) +
  geom_edge_label(aes(label = prettyNum(breaks_label,  digits = 1)), size = 4) +
  geom_node_label(aes(label = splitvar), ids = "inner") +
  geom_node_plot(gglist = list(geom_boxplot(aes(x = '', y = Petal.Length, 
                                                fill = factor(id))),
                               theme_classic(), ylab(""),
                               theme(axis.text.x = element_blank(),
                                     axis.title.x = element_blank(),
                                     axis.ticks.x = element_blank()),
                               theme(legend.position = "none")),
                 shared_legend = FALSE) +
  geom_node_label(aes(label = paste0("Node =", id)),
                  fontface = "bold",
                  ids = "terminal",
                  size = 3,
                  nudge_y = 0.01)
arb

This gives the following plot:

[image: example tree 2]

and here is the output of the tree, with the mean value of each terminal node highlighted:

[image: example tree]

I would like geom_node_label to show the mean value of each class (for example, '1.530' in the place where 'Node 3' is written). Is that possible? I can't find where in ggparty(py)$data I can access the mean value for each terminal node.

Thank you for your help and for this amazing package,

Aurore

martin-borkovec commented 3 years ago

I'm not sure, but I think the mean is calculated each time py2 is printed to the console, so it's not stored in the object. At least, I don't know where it would be stored either.

So you have to calculate it yourself for each node. You can do this with ggparty's add_vars argument as follows, and then access the new variable in geom_node_label. Hope that helps :)

library(rpart)
library(ggparty)
#> Loading required package: ggplot2
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
data("iris")

result_tree <- rpart(Petal.Length ~  Sepal.Width + Sepal.Length, data = iris)

py2 <- as.party(result_tree)

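# add_vars adds one extra column to the plot data, computed per node;
# here it is the rounded mean of Petal.Length over the node's observations
ggparty(py2,
        add_vars = list(mean_petal_length = function(data, node) {
          round(mean(node$data$Petal.Length), 2)
        })) +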
  geom_edge(size = 0.6, color = 'grey60', linetype = 2) +
  geom_edge_label(aes(label = prettyNum(breaks_label,  digits = 1)), size = 4) +
  geom_node_label(aes(label = splitvar), ids = "inner") +
  geom_node_plot(gglist = list(geom_boxplot(aes(x = '', y = Petal.Length, 
                                                fill = factor(id))),
                               theme_classic(), ylab(""),
                               theme(axis.text.x = element_blank(),
                                     axis.title.x = element_blank(),
                                     axis.ticks.x = element_blank()),
                               theme(legend.position = "none")),
                 shared_legend = FALSE) +
  geom_node_label(aes(label = paste0("Mean = ", mean_petal_length)),
                  fontface = "bold",
                  ids = "terminal",
                  size = 3,
                  nudge_y = 0.01)

Created on 2021-01-14 by the reprex package (v0.3.0)
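
If you want to double-check the numbers shown in the labels, the per-node means can also be computed outside of ggparty. The following is only a minimal sketch, not something taken from the ggparty documentation: it assumes that predict() on the party object with type = "node" returns the terminal node id of each training observation, and that the rpart fit keeps its fitted node values in result_tree$frame$yval. The names node_id, node_means and leaf_means are made up for this illustration.

```r
library(rpart)
library(partykit)

data("iris")
result_tree <- rpart(Petal.Length ~ Sepal.Width + Sepal.Length, data = iris)
py2 <- as.party(result_tree)

# Terminal node id (partykit numbering) of every training observation
node_id <- predict(py2, type = "node")

# Mean of the response within each terminal node
node_means <- tapply(iris$Petal.Length, node_id, mean)
round(node_means, 2)

# Cross-check against the values stored in the rpart object:
# rows of `frame` with var == "<leaf>" are the terminal nodes
# (their row names use rpart's node numbering, not partykit's ids)
leaf_means <- result_tree$frame$yval[result_tree$frame$var == "<leaf>"]
round(leaf_means, 2)
```

Both vectors should contain the same set of means, which should in turn match the values produced by the add_vars function above after rounding.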