martin-borkovec / ggparty

Terminal node position #41

Closed: ghost closed this issue 3 years ago

ghost commented 4 years ago

Hi!!

I'm using ggparty to plot a binary tree, and all the terminal nodes are plotted at the bottom. Would it be possible to automatically plot nodes 4-5, 6-7, 8-9, 12-13, 14-15, etc. at the same level? I'm looking for an alternative to writing the coordinates manually.

Thanks in advance.

[Image: tree]

martin-borkovec commented 4 years ago

Unfortunately no, not automatically (but it's something that would be worth implementing some day...).

But you can use the internal get_plot_data() function so you don't have to define all the coordinates manually, and then only adjust the y coordinates of the terminal nodes in question, like this:

library(ggparty)
data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no"))),
  partynode(5L, info = "yes"),
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

# get the layout data.frame (first three columns: id, x, y)
layout <- ggparty:::get_plot_data(py)[, 1:3]

# move terminal node 5 up to the level of its sibling, inner node 2
layout$y[5] <- layout$y[2]

# plot with the adjusted layout
ggparty(py, layout = layout) +
  geom_edge() +
  geom_edge_label() +
  geom_node_label(aes(label = splitvar),
                  ids = "inner") +
  geom_node_label(aes(label = info),
                  ids = "terminal")

Created on 2020-07-27 by the reprex package (v0.3.0)
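Setting `layout$y[5]` by row position assumes that row 5 holds node 5. A minimal sketch of a helper that looks nodes up by their `id` instead, assuming the layout has the `id` and `y` columns that the `layout` argument of `ggparty()` expects. `align_nodes_y()` is a hypothetical name, not a ggparty function, and the coordinates below are made-up toy values, not real `get_plot_data()` output:

```r
# Hypothetical helper, not part of ggparty: give the nodes in `move_ids`
# the same y coordinate as node `to_id`, looking rows up by node id.
align_nodes_y <- function(layout, move_ids, to_id) {
  target_y <- layout$y[layout$id == to_id]
  stopifnot(length(target_y) == 1)  # `to_id` must match exactly one row
  layout$y[layout$id %in% move_ids] <- target_y
  layout
}

# Toy layout with invented coordinates for the WeatherPlay tree above
# (root 1, inner nodes 2 and 6, terminal nodes 3, 4, 5, 7, 8).
toy <- data.frame(
  id = 1:8,
  x  = c(0.5, 0.2, 0.1, 0.3, 0.5, 0.8, 0.7, 0.9),
  y  = c(1, 0.5, 0, 0, 0, 0.5, 0, 0)
)

toy <- align_nodes_y(toy, move_ids = 5, to_id = 2)
toy$y
#> [1] 1.0 0.5 0.0 0.0 0.5 0.5 0.0 0.0
```

With the real layout this would be something like `layout <- align_nodes_y(layout, move_ids = 5, to_id = 2)` before calling `ggparty(py, layout = layout)`.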

ghost commented 4 years ago

I really appreciate your help. I wrote a for loop (surely not optimal) that raises the terminal nodes:

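# get the layout, this time also keeping the fourth column (parent)
layout <- ggparty:::get_plot_data(py)[, 1:4]

# start at row 2 because the root (row 1) has no parent
for(i in 2:(nrow(layout)-1)) {
  for(j in (i+1):nrow(layout)) {
    # if nodes i and j are siblings at different heights,
    # raise the lower one to the level of the higher one
    if(layout$parent[i] == layout$parent[j] && layout$y[i] > layout$y[j]) {
      layout$y[j] <- layout$y[i]
      break
    } else if(layout$parent[i] == layout$parent[j] && layout$y[j] > layout$y[i]) {
      layout$y[i] <- layout$y[j]
      break
    }
  }
}

# pass only the id, x and y columns as the layout
ggparty(py, horizontal = horizontal,
        add_vars = list(R = "$node$info$R",
                        Y = "$node$info$Y",
                        Id = "$node$info$Id",
                        n = "$node$info$n",
                        PR = "$node$info$PR"),
        layout = layout[, 1:3]) + [...]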

[Image: tree2]
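The nested loop can also be written without explicit loops. A sketch, assuming `layout` has the `id`, `x`, `y` and `parent` columns used above and (as an assumption) that the root's parent is `NA`, which is why the loop starts at row 2. Base R's `ave()` replaces each node's y with the maximum y among all nodes sharing its parent. `raise_to_siblings()` is a hypothetical helper, and the toy coordinates are invented for illustration:

```r
# Hypothetical vectorised alternative to the nested loop: every non-root node
# gets the highest y among its siblings (the nodes sharing its parent).
raise_to_siblings <- function(layout) {
  has_parent <- !is.na(layout$parent)  # skip the root, whose parent is NA
  layout$y[has_parent] <- ave(layout$y[has_parent],
                              layout$parent[has_parent],
                              FUN = max)
  layout
}

# Toy layout with invented coordinates for the WeatherPlay tree
toy <- data.frame(
  id     = 1:8,
  x      = c(0.5, 0.2, 0.1, 0.3, 0.5, 0.8, 0.7, 0.9),
  y      = c(1, 0.5, 0, 0, 0, 0.5, 0, 0),
  parent = c(NA, 1, 2, 2, 1, 1, 6, 6)
)

res <- raise_to_siblings(toy)
res$y
#> [1] 1.0 0.5 0.0 0.0 0.5 0.5 0.0 0.0
```

In the code above this would replace the loop with `layout <- raise_to_siblings(layout)`, after which `layout[, 1:3]` goes to `ggparty()` as before. Because `ave()` works per parent group, all siblings of a node are handled at once instead of comparing pairs of rows.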