martin-borkovec / ggparty


Using a different plot_call in geom_node_plot #36

Closed HeidiSeibold closed 5 years ago

HeidiSeibold commented 5 years ago

I'd like to use a custom function in geom_node_plot which takes the model as argument. Is this possible? If so, can you help me with the example below? Thanks in advance :sunflower:

library("ggparty")
#> Loading required package: ggplot2
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
library("survival")
library("model4you")
data(GBSG2, package = "TH.data")

## model
bmod <- survreg(Surv(time, cens) ~ horTh, data = GBSG2, model = TRUE)
survreg_plot(bmod)


## tree
tree <- pmtree(bmod)
#> No data given. I'm using data set GBSG2 from the current environment parent.frame(). Please check if that is what you want.
plot(tree, terminal_panel = node_pmterminal(tree, plotfun = survreg_plot))


## plot survival curves in ggparty
# get model from node
get_model <- function(data, node) {
  list(node$info$model)
}

# plot basis
p <- ggparty(tree, add_vars = list(mdl = get_model)) 
names(p$data)
#>  [1] "id"                    "x"                    
#>  [3] "y"                     "parent"               
#>  [5] "birth_order"           "breaks_label"         
#>  [7] "info"                  "info_list"            
#>  [9] "splitvar"              "level"                
#> [11] "kids"                  "nodesize"             
#> [13] "p.value"               "horizontal"           
#> [15] "x_parent"              "y_parent"             
#> [17] "nodedata_time"         "nodedata_cens"        
#> [19] "nodedata_horTh"        "nodedata_age"         
#> [21] "nodedata_menostat"     "nodedata_tsize"       
#> [23] "nodedata_tgrade"       "nodedata_pnodes"      
#> [25] "nodedata_progrec"      "nodedata_estrec"      
#> [27] "nodedata_fitted_nodes" "mdl"

# add edges and inner nodes
p1 <- p + 
  geom_edge() +
  geom_edge_label() +
  geom_node_label(aes(label = splitvar),
    ids = "inner")

# add terminal node plots
p1 +
  geom_node_plot(plot_call = survreg_plot, gglist = list())
#> Error in (function (mod, data = NULL, theme = theme_classic(), yrange = NULL) : argument "mod" is missing, with no default

Created on 2019-06-20 by the reprex package (v0.2.1)

martin-borkovec commented 5 years ago

Thanks for using ggparty and helping find issues!

So, this would seem like a case for which ggparty's argument plot_call exists. Sadly, I don't think it's really possible to get it to work here. Since geom_node_plot() works with the help of facet_wrap(), the data entry of the ggplot object needs to contain the node id. As far as I can tell, this is not the case with the objects created by survreg_plot(), even when you use the data argument. It would be necessary to change the mapping of survreg_plot() to include the id.

library("survival")
library("model4you")
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
library("ggparty")
#> Loading required package: ggplot2
data(GBSG2, package = "TH.data")

## model
bmod <- survreg(Surv(time, cens) ~ horTh, data = GBSG2, model = TRUE)
surv_gg <- survreg_plot(bmod, data = cbind(GBSG2, id = 1))
## no id column
head(surv_gg$data)
#>   horTh        pr probability
#> 1    no  56.22822        0.99
#> 2    no 133.23025        0.97
#> 3    no 199.84924        0.95
#> 4    no 261.78897        0.93
#> 5    no 320.98908        0.91
#> 6    no 378.42847        0.89

Created on 2019-06-20 by the reprex package (v0.3.0)

So this seems to be a problem with this approach that's not so easy to work around.

In addition, with the current version of geom_node_plot it's not possible to specify further arguments for the new plot_call. It only provides the data argument to any plot_call, which is why you receive the error "argument "mod" is missing, with no default". This could easily be changed, e.g. by adding a new argument where you can specify the plot_call arguments. But I'm not sure if I should add yet another argument for a feature with not that many use cases that I know of. How about you? Have you used the plot_call argument successfully with any other ggplot-based function?
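To illustrate just the missing-argument part in plain R: the sketch below assumes, as described above, that the plotting function is called with nothing but a data argument. The names plot_fn and wrapped_fn are hypothetical stand-ins, not part of ggparty or model4you, and the do.call() line only mimics the described behaviour rather than reproducing ggparty's internals.

```r
# Hypothetical stand-in for a plot function like survreg_plot():
# `mod` is required, `data` is optional.
plot_fn <- function(mod, data = NULL) {
  paste("plotting a model of class", class(mod))
}

# Assumption: the caller supplies only `data`, so `mod` is never filled in.
msg <- tryCatch(
  do.call(plot_fn, list(data = data.frame(x = 1))),
  error = function(e) conditionMessage(e)
)
print(msg)

# A closure can fix `mod` in advance, so that a data-only call succeeds.
wrapped_fn <- function(data) plot_fn(mod = "fitted model", data = data)
ok <- do.call(wrapped_fn, list(data = data.frame(x = 1)))
print(ok)
```

Note that a wrapper like this would only address the missing-argument error. It does nothing about the node id that facet_wrap() needs in the plot data.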

However, to get back to your problem: although it looks like we could use plot_call, I don't think it can be done in this case. What we could do instead is supply the data directly via the gglist argument, after creating a data frame with the plot data for each node. This is not a very elegant solution, but it should be applicable in most cases.

tree <- pmtree(bmod)
#> No data given. I'm using data set GBSG2 from the current environment parent.frame(). Please check if that is what you want.

# get data for geom_node_plot's gglist
survplot_data <- vector("list", length(tree))
for (i in seq_along(tree)) {
    gg <- survreg_plot(tree[i]$info$model)
    survplot_data[[i]] <- cbind(gg$data, id = i)
}
survplot_data <- do.call(rbind, survplot_data)

# plot basis
p <- ggparty(tree) + 
  geom_edge() +
  geom_edge_label() +
  geom_node_label(aes(label = splitvar),
                  ids = "inner")

p + 
  geom_node_plot(gglist = list(geom_line(data = survplot_data,
                                         mapping = aes(x = pr,
                                                       y = probability,
                                                       colour = horTh)
  )))

Created on 2019-06-20 by the reprex package (v0.3.0)

So this is basically how it could work. As you can see, it doesn't really. This is because tree[i]$info$model is identical in every node. I think the actual model we want to plot is stored in tree[i]$node$info$object. However, passing this to survreg_plot produces the following error. But maybe you know how to fix this?

survreg_plot(tree[1]$node$info$object)
#> Error in eval(modcall$data): object 'di' not found

Created on 2019-06-20 by the reprex package (v0.3.0)

HeidiSeibold commented 5 years ago

Wow thanks for the super quick and useful help!

Got it now:

library("survival")
library("model4you")
library("ggparty")
library("plyr")
data(GBSG2, package = "TH.data")

## model
bmod <- survreg(Surv(time, cens) ~ horTh, data = GBSG2, model = TRUE)
tree <- pmtree(bmod)

# get data for geom_node_plot's gglist 
obs_nodes <- predict(tree, type = "node")
get_plot_data <- function(i, data) {
    dat <- subset(data, obs_nodes == i)
    imod <- update(bmod, data = dat)
    gg <- survreg_plot(imod, data = dat)
    cbind(gg$data, id = i)
}
survplot_data <- ldply(unique(obs_nodes), .fun = get_plot_data, data = GBSG2)

# plot 
p <- ggparty(tree) + 
    geom_edge() +
    geom_edge_label() +
    geom_node_label(aes(label = splitvar),
        ids = "inner")

p + 
    geom_node_plot(gglist = list(geom_line(data = survplot_data,
        mapping = aes(x = pr,
            y = probability,
            colour = horTh)
    ), xlim(0, 2900), xlab("time"),
        theme_classic()))

Created on 2019-06-21 by the reprex package (v0.2.1)

Great work! Looks beautiful.
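A side note on the ldply() step above: if the extra plyr dependency is unwanted, the same "one data frame per node, stacked with an id column" pattern can be written in base R with lapply() and rbind(). The sketch below uses toy stand-ins (toy, a hard-coded obs_nodes, and a get_plot_data that only computes a mean) instead of the real model and survreg_plot(), so it only shows the stacking pattern.

```r
# Toy stand-ins (assumptions): a node id per observation and a small data set.
obs_nodes <- c(2, 2, 3, 3, 3)
toy <- data.frame(time = c(10, 20, 30, 40, 50))

# Per-node helper: subset to node i, compute something, tag the result with id.
get_plot_data <- function(i, data) {
  dat <- data[obs_nodes == i, , drop = FALSE]
  data.frame(mean_time = mean(dat$time), id = i)
}

# Base-R counterpart of ldply(unique(obs_nodes), .fun = get_plot_data, data = toy)
plot_data <- do.call(
  rbind,
  lapply(unique(obs_nodes), get_plot_data, data = toy)
)
print(plot_data)
```

The resulting data frame has one block of rows per node, each carrying the id column that geom_node_plot needs to place the data in the right node.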

martin-borkovec commented 5 years ago

great, I'm glad you worked it out!