JingyuHe / XBART


Prediction is zero which is incorrect #100

Open rsparapa opened 4 years ago

rsparapa commented 4 years ago

Hi Jingyu:

The program below returns posterior samples that are all zero, including the sigma draws.


library(XBART)

f = function(x)
    10*sin(pi*x[ , 1]*x[ , 2]) + 5*x[ , 3]*x[ , 4]^2 + 20*x[ , 5]

N = 10000
sigma = 1.0 ##y = f(x) + sigma*z where z~N(0, 1)
P = 25       ##number of covariates
B=8

V = diag(P)
V[5, 6] = 0.8
V[6, 5] = 0.8
L <- chol(V)
set.seed(12)
x.train=matrix(rnorm(N*P), N, P) %*% L
dimnames(x.train)[[2]] <- paste0('x', 1:P)
y.train=(f(x.train)+sigma*rnorm(N))

H=20
x=seq(-3, 3, length.out=H+1)[-(H+1)]
x.test=matrix(0, nrow=H, ncol=P)
x.test[ , 5]=x

##(L=0.25*(log(N)^(log(log(N)))))

post = XBART.CLT(cbind(y.train), x.train, x.test,
                 num_trees=50, num_sweeps=40,
                 burnin=15)
post$yhat.test=post$yhats_test
##post$yhats_test=NULL
post$yhat.test.mean=apply(post$yhat.test, 1, mean)
post$yhat.test.025=apply(post$yhat.test, 1, quantile, probs=0.025)
post$yhat.test.975=apply(post$yhat.test, 1, quantile, probs=0.975)

plot(x, f(x.test), col='blue', type='l', ylab='f(x)')
lines(x, post$yhat.test.mean)
dev.copy2pdf(file='bigdata.pdf')    
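One way to confirm the symptom before plotting is to summarize every numeric element of the returned list. The sketch below is only an illustration: `check_draws` is a hypothetical helper, not part of XBART, and the `mock_post` object stands in for the real fit. The element name `yhats_test` is taken from the program above, while `sigma` as an element name is an assumption.

```r
# Hypothetical helper (not part of XBART): summarize each numeric element
# of a fitted-model list so that degenerate all-zero output is easy to spot.
check_draws <- function(post) {
    numeric_parts <- Filter(is.numeric, post)
    do.call(rbind, lapply(names(numeric_parts), function(name) {
        v <- as.numeric(numeric_parts[[name]])
        data.frame(element = name,
                   n = length(v),
                   min = min(v),
                   max = max(v),
                   all_zero = all(v == 0),
                   stringsAsFactors = FALSE)
    }))
}

# Stand-in for the object returned by XBART.CLT; the element name `sigma`
# and both matrix shapes are assumptions made for this illustration.
mock_post <- list(yhats_test = matrix(0, nrow = 20, ncol = 40),
                  sigma = matrix(0, nrow = 50, ncol = 40))
print(check_draws(mock_post))
```

With the real fit, `check_draws(post)` would show at a glance whether only the test predictions are zero or every returned quantity is.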

The R output is as follows.


R version 3.5.2 (2018-12-20) -- "Eggshell Igloo"
Copyright (C) 2018 The R Foundation for Statistical Computing
Platform: x86_64-pc-linux-gnu (64-bit)

R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.

  Natural language support but running in an English locale

R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.

Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.

> setwd('/home/rsparapa/git/XBART/demo')
> options(width=78, length=99999)
> library(XBART)
> f = function(x)
+     10*sin(pi*x[ , 1]*x[ , 2]) + 5*x[ , 3]*x[ , 4]^2 + 20*x[ , 5]
> N = 10000
> sigma = 1.0 ##y = f(x) + sigma*z where z~N(0, 1)
> P = 25       ##number of covariates
> B=8
> V = diag(P)
> V[5, 6] = 0.8
> V[6, 5] = 0.8
> L <- chol(V)
> set.seed(12)
> x.train=matrix(rnorm(N*P), N, P) %*% L
> dimnames(x.train)[[2]] <- paste0('x', 1:P)
> y.train=(f(x.train)+sigma*rnorm(N))
> H=20
> x=seq(-3, 3, length.out=H+1)[-(H+1)]
> x.test=matrix(0, nrow=H, ncol=P)
> x.test[ , 5]=x
> ##(L=0.25*(log(N)^(log(log(N)))))
> 
> post = XBART.CLT(cbind(y.train), x.train, x.test,
+                  num_trees=50, num_sweeps=40,
+                  burnin=15)
tau = 1/num_trees, default value. 
mtry = p, use all variables. 
> post$yhat.test=post$yhats_test
> ##post$yhats_test=NULL
> post$yhat.test.mean=apply(post$yhat.test, 1, mean)
> post$yhat.test.025=apply(post$yhat.test, 1, quantile, probs=0.025)
> post$yhat.test.975=apply(post$yhat.test, 1, quantile, probs=0.975)
> plot(x, f(x.test), col='blue', type='l', ylab='f(x)')
> lines(x, post$yhat.test.mean)
> dev.copy2pdf(file='bigdata.pdf')
X11cairo 
       2 
> library(help=XBART)

        Information on package ‘XBART’

Description:

Package:            XBART
Type:               Package
Title:              XBART: Accelerated Bayesian Additive Regression
                    Trees
Version:            0.2
Date:               2019-09-5
Author:             Jingyu He, Saar Yalov, P. Richard Hahn, Lee
                    Reeves
Maintainer:         Jingyu He <jingyu.he@chicagobooth.edu>
Description:        A highly efficient prediction algorithm based on
                    trees.
License:            Apache License (== 2.0)
Imports:            Rcpp (>= 0.12.13)
LinkingTo:          Rcpp, RcppArmadillo
NeedsCompilation:   yes
Packaged:           2020-05-20 21:08:35 UTC; rsparapa
Built:              R 3.5.2; x86_64-pc-linux-gnu; 2020-05-20 21:08:48
                    UTC; unix

Index:

XBART                   XBART: Accelerated Bayesian Additive Regression
                        Trees
XBART-package           XBART: Accelerated Bayesian Additive Regression
                        Trees
XBART.CLT               XBART: Accelerated Bayesian Additive Regression
                        Trees
XBART.Probit            XBART: Accelerated Bayesian Additive Regression
                        Trees

> sessionInfo()
R version 3.5.2 (2018-12-20)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)

Matrix products: default
BLAS: /usr/lib64/libblas.so.3.4.2
LAPACK: /usr/lib64/liblapack.so.3.4.2

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] XBART_0.2

loaded via a namespace (and not attached):
[1] compiler_3.5.2 tools_3.5.2    Rcpp_1.0.4    
> 
rsparapa commented 4 years ago

Are you planning to respond to these at some point? It has been almost 6 months.