BelindaHernandez / bartBMA

get_subset function and grow_tree #5

Closed EoghanONeill closed 5 years ago

EoghanONeill commented 5 years ago

The get_subset function returns the first grow_obs.size() rows of the input data. However, it should probably return the rows indexed by grow_obs?

This is important for lines 946-952 of BARTBMA_SumTreeLikelihood.cpp, so that the function correctly continues to the next changepoint if the current potential changepoint would send too few observations left or right.


library(devtools)
install_github("EoghanONeill/bartBMA")
library(bartBMA)

#simulate data
N <- 1000
set.seed(100)
x1 <- runif(N)
x2 <- runif(N)
x3 <- runif(N)
x4 <- runif(N)
x5 <- runif(N)
x6 <- runif(N)
x7 <- runif(N)
x8 <- runif(N)
x9 <- runif(N)
x10 <- runif(N)
epsilon <- rnorm(N)

xcov <- cbind(x1,x2,x3,x4,x5,x6,x7,x8,x9,x10)
y <- sin(pi*x1*x2) + 20*(x3-0.5)^2+10*x4+5*x5+epsilon

bart_bma_example <- bartBMA(x.train = xcov, y.train = y)

tree_mat_issue <- bart_bma_example[[3]][[1]][[1]] #a tree matrix

# then check if the correct rows are obtained by get_subset
# It can be seen in the C++ code that get_subset only obtains
# the first grow_obs.size() rows of the data matrix.
# Therefore the function needs to be changed to be more similar to 
# get_grow_obs, but obtaining the relevant rows from the whole matrix, rather than from one column
grow_obs_issue <- find_term_obs(tree_mat_issue,11)
data_curr_node_example <- get_subset(xcov, grow_obs_issue)

# the last element of grow_obs_issue is 995 (zero-based), which refers to the 996th row,
# but data_curr_node_example is just the first length(grow_obs_issue) rows of xcov
all(data_curr_node_example == xcov[1:147,])

# This matters for defining curr_cols2 and throwing an error on line 932 (this line is ok),
# and for continuing to the next loop iteration on lines 946-952 of BARTBMA_SumTreeLikelihood.cpp.
# get_grow_obs could be used there instead.
# get_subset is also used to define data_curr_node, the input for grow_tree,
# on lines 954 and 1138, but grow_tree does not use this input.
# grow_tree can therefore be edited to remove it on line 470 (and on lines 954 and 1138).

EoghanONeill commented 5 years ago

The easiest solution is to remove the get_subset() function and replace the following

data_curr_node=get_subset(data,wrap(grow_obs))

with

data_curr_node=data.rows(grow_obs)

BelindaHernandez commented 5 years ago

Well spotted!

Yes, this was a typo in the get_subset function: it was supposed to get the grow_obs[i]-th row of the dataset but was only getting the i-th row.

Your solution also works, so we'll go with that!
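
For readers following the thread, the difference between the two behaviours can be sketched in plain C++. This is only an illustration under stated assumptions, not the package's code: the `Matrix` alias and the function names `subset_first_rows` and `subset_indexed_rows` are hypothetical stand-ins, and `grow_obs` is assumed to hold zero-based row indices, as in the example above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major stand-in for the data matrix (not the package's type).
using Matrix = std::vector<std::vector<double>>;

// Behaviour described in the issue: the loop copies row i, so the result is
// always the first grow_obs.size() rows, whatever indices grow_obs contains.
// Assumes grow_obs.size() <= data.size().
Matrix subset_first_rows(const Matrix& data,
                         const std::vector<std::size_t>& grow_obs) {
    Matrix out;
    out.reserve(grow_obs.size());
    for (std::size_t i = 0; i < grow_obs.size(); ++i) {
        out.push_back(data.at(i));  // uses i, ignoring grow_obs[i]
    }
    return out;
}

// Intended behaviour: copy the rows whose zero-based indices are in grow_obs.
Matrix subset_indexed_rows(const Matrix& data,
                           const std::vector<std::size_t>& grow_obs) {
    Matrix out;
    out.reserve(grow_obs.size());
    for (std::size_t i = 0; i < grow_obs.size(); ++i) {
        out.push_back(data.at(grow_obs[i]));  // looks up the stored index
    }
    return out;
}
```

With a four-row matrix and `grow_obs = {1, 3}`, `subset_first_rows` returns rows 0 and 1, while `subset_indexed_rows` returns rows 1 and 3. Any count of observations sent left or right that is computed from the first result therefore describes the wrong observations.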