bethatkinson / rpart

Recursive Partitioning and Regression Trees

Add more arguments to plot.rpart() #19

Closed: bgreenwell closed this issue 2 years ago

bgreenwell commented 4 years ago

@bethatkinson would you be open to a PR that adds a few extra arguments to plot.rpart()? In particular, I was hoping to add three new arguments to control the branch color, branch line type, and branch width. This would be useful in some cases (e.g., highlighting specific tree diagrams in a plot, illustrating the pruning sequence, etc.). The required changes are minimal; the modified function is given below:

plot.rpart <- function(x, uniform = TRUE, branch = 1, compress = TRUE,
                      nspace, margin = 0, minbranch = 0.3, branch.col = 1,
                      branch.lty = 1, branch.lwd = 1, ...) {
  if (!inherits(x, "rpart")) stop("Not a legitimate \"rpart\" object")
  if (nrow(x$frame) <= 1L) stop("fit is not a tree, just a root")

  if (compress && missing(nspace)) nspace <- branch
  if (!compress) nspace <- -1L     # means no compression
  ## if (dev.cur() == 1L) dev.new() # not needed in R

  parms <- list(uniform = uniform, branch = branch, nspace = nspace,
                minbranch = minbranch)

  ## define the plot region
  temp <- rpart:::rpartco(x, parms)
  xx <- temp$x
  yy <- temp$y
  temp1 <- range(xx) + diff(range(xx)) * c(-margin, margin)
  temp2 <- range(yy) + diff(range(yy)) * c(-margin, margin)
  plot(temp1, temp2, type = "n", axes = FALSE, xlab = "", ylab = "", ...)
  ## Save information per device, once a new device is opened.
  assign(paste0("device", dev.cur()), parms, envir = rpart:::rpart_env)

  # Draw a series of horseshoes or V's, left son, up, down to right son
  #   NA's in the vector cause lines() to "lift the pen"
  node <- as.numeric(row.names(x$frame))
  temp <- rpart:::rpart.branch(xx, yy, node, branch)

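  if (branch > 0) text(xx[1L], yy[1L], "|")
  ## New: branch.col, branch.lty and branch.lwd are passed on to lines(),
  ## so they control the color, line type and width of every branch segment
  lines(c(temp$x), c(temp$y), col = branch.col, lty = branch.lty,
        lwd = branch.lwd)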
  invisible(list(x = xx, y = yy))
}

Example

library(rpart)
tree <- rpart(mpg ~ ., data = mtcars)
plot(tree, branch.lty = 2, branch.col = 2, branch.lwd = 2)
text(tree)

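[image: https://user-images.githubusercontent.com/3777755/83192269-49010980-a103-11ea-92ab-3df3be9d18e0.png]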
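As a sketch of the pruning-sequence and highlighting use cases mentioned in the request, the following draws each pruned subtree of the `mtcars` tree side by side and emphasizes the one with the smallest cross-validated error. It assumes the modified `plot.rpart()` above has been defined in the session; the colors, the `set.seed()` value and the "highlight the minimum `xerror`" rule are illustrative choices, not part of the proposal.

```r
library(rpart)

# Assumption: the modified plot.rpart() from this issue has been defined in
# the session, so plot() accepts branch.col, branch.lty and branch.lwd.
set.seed(101)  # xerror in the cp table depends on random cross-validation folds
tree <- rpart(mpg ~ ., data = mtcars)
cpt  <- tree$cptable

# Row of the cp table whose subtree has the smallest cross-validated error
best <- which.min(cpt[, "xerror"])

# A root-only subtree cannot be plotted ("fit is not a tree, just a root"),
# so only rows with at least one split are drawn.
rows <- which(cpt[, "nsplit"] > 0)

op <- par(mfrow = c(1, length(rows)))
for (i in rows) {
  subtree <- prune(tree, cp = cpt[i, "CP"])
  chosen  <- i == best
  # Highlight the selected subtree; draw the others dashed and grey
  plot(subtree, margin = 0.1,
       branch.col = if (chosen) "red" else "grey50",
       branch.lty = if (chosen) 1 else 2,
       branch.lwd = if (chosen) 2 else 1,
       main = paste("nsplit =", cpt[i, "nsplit"]))
  text(subtree, cex = 0.8)
}
par(op)
```

Because `xerror` comes from random cross-validation folds, the highlighted subtree can change with the seed.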

bethatkinson commented 4 years ago

Sure, that makes sense to me.

Beth


From: Brandon Greenwell notifications@github.com Sent: Thursday, May 28, 2020 3:51 PM To: bethatkinson/rpart rpart@noreply.github.com Cc: Atkinson, Elizabeth J. (Beth), M.S. atkinson@mayo.edu; Mention mention@noreply.github.com Subject: [EXTERNAL] [bethatkinson/rpart] Add more arguments to plot.rpart() (#19)


bgreenwell commented 4 years ago

Alright, I'll try to pull together a PR in the next week or so!

bgreenwell commented 4 years ago

PR here: https://github.com/bethatkinson/rpart/pull/20.