grf-labs / policytree

Policy learning via doubly robust empirical welfare maximization over trees
https://grf-labs.github.io/policytree/
MIT License

Return R-indexed node.id #82

Closed by erikcs 3 years ago

erikcs commented 3 years ago

The tree is printed with R-indexed (1-based) node numbers, so the node IDs returned by `predict` should be 1-indexed too.

(In plots, DiagrammeR's hover-over node ID is 0-indexed. We stick with R indexing for consistency.)

```r
> library(policytree)
> n <- 25
> p <- 5
> d <- 3
> X <- matrix(runif(n * p), n, p)
> Y <- matrix(rnorm(n * d), n, d)
> tree <- policy_tree(X, Y, depth = 2)
> tree
policy_tree object 
Tree depth:  2 
Actions:  1 2 3 
Variable splits: 
(1) split_variable: X3  split_value: 0.832141 
  (2) split_variable: X3  split_value: 0.507951 
    (4) * action: 2 
    (5) * action: 1 
  (3) split_variable: X3  split_value: 0.868828 
    (6) * action: 2 
    (7) * action: 3 
> predict(tree, X, type = "node.id")
 [1] 4 4 6 7 5 7 5 5 5 6 6 5 7 4 4 5 6 7 4 4 5 4 5 4 4
```
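The printout suggests a heap-style numbering: root `(1)`, node `k` has children `2k` and `2k + 1`, and the depth-2 leaves are `(4)` to `(7)`. The sketch below assumes that layout holds in general. That assumption is inferred only from this one printout. It shows how a 1-indexed `node.id` relates to the tree structure and how a 0-indexed ID (such as DiagrammeR's hover-over value) maps onto it. The helpers `to_r_index`, `parent_node`, `path_from_root` and `leaf_nodes` are hypothetical and are not part of the policytree API.

```r
# Hypothetical helpers (not part of policytree), assuming the heap-style
# numbering seen in the printout: root is node 1, node k has children
# 2k (left) and 2k + 1 (right).

# Map a 0-indexed node id (children 2i + 1, 2i + 2) to the 1-indexed layout.
to_r_index <- function(node0) node0 + 1L

# Parent of a 1-indexed node is obtained by integer division by 2.
parent_node <- function(k) k %/% 2L

# Node ids visited from the root down to node k (root first).
path_from_root <- function(k) {
  path <- k
  while (k > 1L) {
    k <- parent_node(k)
    path <- c(k, path)
  }
  path
}

# Leaf ids of a complete tree of the given depth: 2^depth, ..., 2^(depth + 1) - 1.
leaf_nodes <- function(depth) seq.int(2L^depth, 2L^(depth + 1L) - 1L)
```

For example, `path_from_root(5)` gives `1 2 5`: the root split, the split at `(2)`, then leaf `(5)` (action 1 in the printout above). `leaf_nodes(2)` gives `4 5 6 7`, which matches the set of values returned by `predict(tree, X, type = "node.id")`.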