JingyuHe / XBART


XBART resetting RNG seed? #144

Open · lmiratrix opened this issue 1 year ago

lmiratrix commented 1 year ago

We are running a simulation that repeatedly generates data and fits XBART. We found that, on a Mac, XBART appears to reset (or clobber) the RNG seed, so the random numbers drawn after each call replicate exactly.

This is demonstrated in the attached code, reproduced below. Note that the 2nd and 3rd generated data sets are exactly the same when XBART is called, but differ when it is not called (do_bart = FALSE).
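The symptom is consistent with the fitting routine calling `set.seed()` on R's global RNG internally. The sketch below illustrates that hypothesis and one possible workaround. `fit_that_reseeds()` is a hypothetical stand-in, not the XBART or BART source. `with_preserved_rng()` is a hypothetical helper that saves `.Random.seed` before the call and restores it afterwards.

```r
# Hypothetical stand-in for a fitting routine that reseeds the global RNG.
# This is an illustrative assumption, not the XBART or BART source code.
fit_that_reseeds <- function(seed = 99L) {
  set.seed(seed)          # overwrites the caller's .Random.seed
  invisible(runif(10))    # pretend to do some random work
}

# Hypothetical helper: evaluate `expr`, then put the global RNG state back
# exactly as it was before `expr` ran.
with_preserved_rng <- function(expr) {
  if (!exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
    stats::runif(1)       # initialise the RNG so there is a state to save
  }
  old_seed <- get(".Random.seed", envir = globalenv())
  on.exit(assign(".Random.seed", old_seed, envir = globalenv()), add = TRUE)
  expr                    # lazy evaluation: the wrapped call runs here
}

# Without protection, the draws made after each fit repeat exactly.
set.seed(101010)
fit_that_reseeds(); a1 <- rnorm(3)
fit_that_reseeds(); a2 <- rnorm(3)
identical(a1, a2)   # TRUE

# With protection, the simulation's own stream keeps advancing.
set.seed(101010)
fit <- with_preserved_rng(fit_that_reseeds()); b1 <- rnorm(3)
fit <- with_preserved_rng(fit_that_reseeds()); b2 <- rnorm(3)
identical(b1, b2)   # FALSE
```

If an extra dependency is acceptable, the withr package's `with_preserve_seed()` serves the same purpose as the hand-written helper.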


library( tidyverse )

#devtools::install_github("soerenkuenzel/causalToolbox")
library( causalToolbox )

set.seed(101010 )

make_data <- function( N, K=5 ) {
  dat = MASS::mvrnorm( N, 
                       mu = rep( 1, K ),
                       Sigma = diag( rep( 1, K ) ) ) %>%
    as.data.frame() 
  colnames( dat ) = paste0( "X", 1:ncol(dat) )
  dat$Z = as.numeric( sample( N ) <= N/2 )
  dat$Y_tau = dat$Z * dat$X3 + rnorm( N )
  dat$Y = dat$X5 + dat$X2 * dat$X3 + dat$Y_tau + rnorm( N )

  return( dat )
}

master_test = make_data( 100 )

perform_simulation_test = function( S = 3, # Number of runs
                                    test_set, 
                                    size_train = 500,
                                    do_bart = TRUE ) { 

  # Validation data (Test data, only covariates)
  x_val = test_set %>%
    dplyr::select( starts_with( "X" ) ) %>%
    as.matrix( )
  nval = nrow( x_val )
  d_val = test_set$Z
  y_val = test_set$Y

  matrix_list <- list() # PP testing
  for (i in 1:S){
    print("RNG state")
    print(.Random.seed)
    # Generate training data
    train_set = make_data( size_train )

    # Separate out into its pieces

    # Train data, only covariates
    x_tr = train_set %>%
      dplyr::select( starts_with( "X" ) ) %>%
      as.matrix()

    # Train data, treatment indicator
    d_tr = train_set$Z

    # Train data, only outcomes
    y_tr = as.matrix( train_set$Y )

    index = caret::createFolds(y_tr, k = 2)

    if ( do_bart ) {
      cat( glue::glue( "Fitting bart now on {length(y_tr)} points, validating on {nval} points..." ) )
      cate_esti_bart = BART::mc.wbart(
        x.train = x_tr,
        y.train = y_tr,
        x.test = x_val,
        ndpost = 1000,
        ntree = 100)$yhat.test
      cat( " ...finished.\n" )

    }

    variable_name <- paste0("run_", i) # PP testing
    matrix_list[[variable_name]] <- train_set # PP testing

    cat( "loop done\n" )
  }  

  # Return list of generated training sets.
  return(matrix_list)
} 

## Run the simulation -----------------------------------------

cat( "Running with Bart\n" )

tictoc::tic()
sim_w_XBART = perform_simulation_test(test_set = master_test,
                                      size_train = 50)
tictoc::toc()

## Check Sim Results -------------------------------------------
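train_set_1 = sim_w_XBART[['run_1']]
train_set_2 = sim_w_XBART[['run_2']]
train_set_3 = sim_w_XBART[['run_3']]

# Should be FALSE; it is TRUE in the runs described above (do_bart = TRUE)
identical( train_set_2, train_set_3 )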

Attachment: XBart clobbers seed.R.zip
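To check whether a particular call is what resets the seed, one option is a small diagnostic. It runs the call from two different starting seeds and compares the draws that follow. `call_reseeds_rng()` is a hypothetical helper written for this sketch, and the two closures passed to it are toy examples.

```r
# Hypothetical diagnostic: returns TRUE when the draws made right after
# calling `f()` are the same regardless of the RNG state before the call,
# which suggests that `f` reseeds R's global RNG internally.
# Note: this helper itself calls set.seed(), so run it outside the simulation.
call_reseeds_rng <- function(f, n = 5) {
  set.seed(1); f(); after_1 <- runif(n)
  set.seed(2); f(); after_2 <- runif(n)
  identical(after_1, after_2)
}

call_reseeds_rng(function() { set.seed(99); invisible(NULL) })  # TRUE
call_reseeds_rng(function() invisible(rnorm(10)))               # FALSE
```

To test the fitting call from the script, `f` could be a closure such as `function() BART::mc.wbart(x.train = x_tr, y.train = y_tr, ndpost = 10, ntree = 10)`, with `x_tr` and `y_tr` prepared as in the script. The small `ndpost` and `ntree` values only keep the check fast. This usage has not been run here.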