vdorie / stan4bart

Uses Stan sampler and math library to semiparametrically fit linear and multilevel models with additive Bayesian Additive Regression Tree (BART) components.

Predict function not working #13

Open sreedta8 opened 1 year ago

sreedta8 commented 1 year ago

Here is my training model:

fit40 <- stan4bart(
  formula = sales ~
    hdummys + tv_ads + dig_ads + prt_ads + # linear component ("fixef")
    (1 | dmaseqid) + # multilevel ("ranef"); dmaseqid is a factor variable
    bart(. - region - coupons - hdummys - tv_ads - dig_ads - prt_ads), # use bart for other variables
  verbose = -1, # suppress ALL output
  data = train, # 8400 rows
  # low numbers for illustration; using only 1 chain
  chains = 1, iter = 100,
  bart_args = list(n.trees = 5, keepTrees = TRUE))

This runs without a problem. Then I call predict as follows:

predict(fit40, newdata = test, # test data has 2520 rows
  type = c("ev", "ppd", "indiv.fixef", "indiv.ranef", "indiv.bart"),
  combine_chains = FALSE, # has only 1 chain, no need to combine
  sample_new_levels = TRUE)

I get the following warning, followed by an error:

Warning message in validateXTest(newdata, attr(data@x, "term.labels"), ncol(data@x), :
“column names of 'test' does not equal that of 'x': 'dmaseqid.1, dmaseqid.2, dmaseqid.3, dmaseqid.4, dmaseqid.5, dmaseqid.6, dmaseqid.7, dmaseqid.8, dmaseqid.9, dmaseqid.10, dmaseqid.11, dmaseqid.12, dmaseqid.13, dmaseqid.14, dmaseqid.15, dmaseqid.16, dmaseqid.17, dmaseqid.18, dmaseqid.19, dmaseqid.20, dmaseqid.21, dmaseqid.22, dmaseqid.23, dmaseqid.24, dmaseqid.25, dmaseqid.26, dmaseqid.27, dmaseqid.28, dmaseqid.29, dmaseqid.30, dmaseqid.31, dmaseqid.32, dmaseqid.33, dmaseqid.34, dmaseqid.35, dmaseqid.36, dmaseqid.37, dmaseqid.38, dmaseqid.39, dmaseqid.40, dmaseqid.41, dmaseqid.42, dmaseqid.43, dmaseqid.44, dmaseqid.45, dmaseqid.46, dmaseqid.47, dmaseqid.48, dmaseqid.49, dmaseqid.50, dmaseqid.51, dmaseqid.52, dmaseqid.53, dmaseqid.54, dmaseqid.55, dmaseqid.56, dmaseqid.57, dmaseqid.58, dmaseqid.59, dmaseqid.60, dmaseqid.61, dmaseqid.62, dmaseqid.63, dmaseqid.64, dmaseqid.65, dmaseqid.66, dmaseqid.67, dmaseqid.68, dmaseqid.69, dmaseqid.70, dmaseqid.71, dmaseqid.72, dmaseqid.73, dmaseqid.74, dmaseqid.75, dmaseqid.76, dmaseqid.77, dmaseqid.78, dmaseqid.79, dmaseqid.80, dmaseqid.81, dmaseqid.82, dmaseqid.83, dmaseqid.84, dmaseqid.85, dmaseqid.86, dmaseqid.87, dmaseqid.88, dmaseqid.89, dmaseqid.90, dmaseqid.91, dmaseqid.92, dmaseqid.93, dmaseqid.94, dmaseqid.95, dmaseqid.96, dmaseqid.97, dmaseqid.98, dmaseqid.99, dmaseqid.100, dmaseqid.101, dmaseqid.102, dmaseqid.103, dmaseqid.104, dmaseqid.105, dmaseqid.106, dmaseqid.107, dmaseqid.108, dmaseqid.109, dmaseqid.110, dmaseqid.111, dmaseqid.112, dmaseqid.113, dmaseqid.114, dmaseqid.115, dmaseqid.116, dmaseqid.117, dmaseqid.118, dmaseqid.119, dmaseqid.120, dmaseqid.121, dmaseqid.122, dmaseqid.123, dmaseqid.124, dmaseqid.125, dmaseqid.126, dmaseqid.127, dmaseqid.128, dmaseqid.129, dmaseqid.130, dmaseqid.131, dmaseqid.132, dmaseqid.133, dmaseqid.134, dmaseqid.135, dmaseqid.136, dmaseqid.137, dmaseqid.138, dmaseqid.139, dmaseqid.140, dmaseqid.141, dmaseqid.142, dmaseqid.143, dmaseqid.144, dmaseqid.145, dmaseqid.146, 
dmaseqid.147, dmaseqid.148, dmaseqid.149, dmaseqid.150, dmaseqid.151, dmaseqid.152, dmaseqid.153, dmaseqid.154, dmaseqid.155, dmaseqid.156, dmaseqid.157, dmaseqid.158, dmaseqid.159, dmaseqid.160, dmaseqid.161, dmaseqid.162, dmaseqid.163, dmaseqid.164, dmaseqid.165, dmaseqid.166, dmaseqid.167, dmaseqid.168, dmaseqid.169, dmaseqid.170, dmaseqid.171, dmaseqid.172, dmaseqid.173, dmaseqid.174, dmaseqid.175, dmaseqid.176, dmaseqid.177, dmaseqid.178, dmaseqid.179, dmaseqid.180, dmaseqid.181, dmaseqid.182, dmaseqid.183, dmaseqid.184, dmaseqid.185, dmaseqid.186, dmaseqid.187, dmaseqid.188, dmaseqid.189, dmaseqid.190, dmaseqid.191, dmaseqid.192, dmaseqid.193, dmaseqid.194, dmaseqid.195, dmaseqid.196, dmaseqid.197, dmaseqid.198, dmaseqid.199, dmaseqid.200, dmaseqid.201, dmaseqid.202, dmaseqid.203, dmaseqid.204, dmaseqid.205, dmaseqid.206, dmaseqid.207, dmaseqid.208, dmaseqid.209, dmaseqid.210, hdummys, tv_ads, dig_ads, prt_ads, region, coupons'; match will be made by position”

Error in dimnames(indiv.bart) <- list(observation = NULL, sample = NULL, : length of 'dimnames' [3] must match that of 'dims' [2]
Traceback:

1. predict(fit40, newdata = test, type = c("ev", "ppd", "indiv.fixef", 
 .     "indiv.ranef", "indiv.bart"), combine_chains = FALSE, sample_new_levels = TRUE)
2. predict.stan4bartFit(fit40, newdata = test, type = c("ev", "ppd", 
 .     "indiv.fixef", "indiv.ranef", "indiv.bart"), combine_chains = FALSE, 
 .     sample_new_levels = TRUE)
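For readers unfamiliar with this message: it comes from base R and is raised whenever a dimnames list has more entries than the object has dimensions. The sketch below reproduces it without stan4bart. It is only an illustration of the error itself, not a diagnosis of what predict does internally; the matrix shape and the third name (`third`) are hypothetical placeholders.

```r
# Base R only; this does not call stan4bart.
# A 2-d object (say observation x sample) given three dimnames entries.
m <- matrix(0, nrow = 4, ncol = 3)

msg <- tryCatch({
  # 'third' is a placeholder name, not taken from the package source
  dimnames(m) <- list(observation = NULL, sample = NULL, third = NULL)
  "no error"
}, error = function(e) conditionMessage(e))
print(msg)

# With an explicit third dimension of length 1, the same assignment is accepted.
a <- array(m, dim = c(dim(m), 1L))
dimnames(a) <- list(observation = NULL, sample = NULL, third = NULL)
print(dim(a))
```

So the message suggests that some object was 2-d at a point where the code expected three dimensions, though where that happens inside the package is not shown by the traceback.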

Does this have a solution? My train and test data frames have exactly the same columns; only the number of rows differs. I read here that using a single chain avoids the error related to the number of dimensions of the bart component.
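A claim like "the same columns" can be checked directly in base R. The sketch below uses small made-up data frames (only the column names `sales`, `tv_ads`, and `dmaseqid` come from the model above; the values are invented) and compares column names, column classes, and the levels of the grouping factor. It assumes `dmaseqid` is a factor in both data frames.

```r
# Toy stand-ins for the real train/test data; values are made up.
train <- data.frame(sales = c(1, 2, 3),
                    tv_ads = c(0.1, 0.2, 0.3),
                    dmaseqid = factor(c("1", "2", "3")))
test <- data.frame(sales = c(4, 5),
                   tv_ads = c(0.4, 0.5),
                   dmaseqid = factor(c("2", "4")))

# Same column names, in the same order?
same_columns <- identical(names(train), names(test))

# Same class for every column?
same_classes <- identical(sapply(train, class), sapply(test, class))

# Grouping levels that appear in test but were never seen in training
new_levels <- setdiff(levels(test$dmaseqid), levels(train$dmaseqid))

print(same_columns)
print(same_classes)
print(new_levels)
```

If `new_levels` is non-empty, the test set contains groups the model was not fitted on, which is the situation `sample_new_levels = TRUE` is meant to cover.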

sreedta8 commented 1 year ago

@vdorie would you be able to help with my issue above?