rstudio / tfestimators

R interface to TensorFlow Estimators
https://tensorflow.rstudio.com/tfestimators

Predict function predicts entire dataset instead of test data #175

Open nagdevAmruthnath opened 5 years ago

nagdevAmruthnath commented 5 years ago

When I run the sample code below, everything works without an error. However, when I call predict() on the test data, it returns predictions for the entire dataset instead of just the test set.

# load libraries
library(tfestimators)
library(tensorflow)
library(caret)

# load data
data("mtcars")

row.names(mtcars) = NULL
# create data partition
samples = createDataPartition(mtcars$mpg, p = 0.8)
mpg_train = mtcars[samples$Resample1,]
mpg_test = mtcars[-samples$Resample1, ]

# return an input_fn for a given subset of data
CA_input_fn =  function(data, num_epochs = 1) {
  input_fn(mtcars, 
           features = names(mtcars[,2:11]), 
           response = "mpg",
           batch_size = 32,
           num_epochs = num_epochs)
}

# create feature columns
cols <- feature_columns(
  tf$feature_column$numeric_column("disp"),
  tf$feature_column$numeric_column("cyl")
)
# dnn model
model = dnn_regressor(hidden_units = c(30, 20, 10),
                      feature_columns = cols)

# train the model
model %>% train(CA_input_fn(mpg_train, num_epochs = 10))

# do evaluation
model %>% evaluate(CA_input_fn(mpg_test))
# A tibble: 1 x 5
# average_loss `label/mean`  loss `prediction/mean` global_step
# <dbl>        <dbl> <dbl>             <dbl>       <dbl>
#   1         166.         20.1 5302.              13.4          10

# dimension of test data
dim(mpg_test)
## [1]  4 11

# do predictions
predictions = model %>% predict(CA_input_fn(mpg_test), simplify = TRUE, predict_keys = "predictions")

# dimension of predictions data
dim(predictions)
## [1] 32  1
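A likely explanation, based only on the code above: CA_input_fn() accepts a `data` argument but never uses it, because its body calls `input_fn(mtcars, ...)`. Every call to it, for training, evaluation and prediction alike, therefore feeds all 32 rows of `mtcars` to the estimator. That matches the 32-row result of predict(), and it also means the model was trained on the test rows. The base-R sketch below illustrates the pattern with a hypothetical stand-in, `fake_input_fn()`, in place of tfestimators' input_fn() so it runs without TensorFlow; it is an illustration under that assumption, not the package's own code.

```r
# Hypothetical stand-in for tfestimators::input_fn(): it only returns the
# rows and columns that would be fed to the estimator (no TensorFlow needed).
fake_input_fn <- function(object, features, response) {
  object[, c(response, features), drop = FALSE]
}

# Pattern from the issue: the `data` argument is ignored and mtcars is used.
buggy_input_fn <- function(data) {
  fake_input_fn(mtcars, features = names(mtcars[, 2:11]), response = "mpg")
}

# Suggested fix: pass `data` through and derive the features from it.
fixed_input_fn <- function(data) {
  fake_input_fn(data, features = setdiff(names(data), "mpg"), response = "mpg")
}

mpg_test <- mtcars[c(3, 10, 20, 30), ]  # a 4-row stand-in for the test split

nrow(buggy_input_fn(mpg_test))  # 32: every row of mtcars
nrow(fixed_input_fn(mpg_test))  # 4: only the rows passed in
```

In the original wrapper, the corresponding change would be replacing `input_fn(mtcars, ...)` with `input_fn(data, ...)`. With that change, predict() on `mpg_test` would be expected to return one prediction per test row, though this is an expectation rather than a verified result against tfestimators 1.9.1.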

Session Information

> sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.3 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/libopenblasp-r0.2.20.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] h2o_3.22.1.1       densoDA_0.1.1      tensorflow_1.13.1  caret_6.0-84       ggplot2_3.2.0      lattice_0.20-38   
[7] tfestimators_1.9.1

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.2         lubridate_1.7.4    forge_0.2.0        prettyunits_1.0.2  class_7.3-15       utf8_1.1.4        
 [7] zeallot_0.1.0      assertthat_0.2.1   ipred_0.9-9        psych_1.8.12       foreach_1.4.4      R6_2.4.0          
[13] plyr_1.8.4         backports_1.1.4    stats4_3.6.1       pillar_1.4.2       tfruns_1.4         rlang_0.4.0       
[19] progress_1.2.2     lazyeval_0.2.2     rstudioapi_0.10    data.table_1.12.2  whisker_0.3-2      rpart_4.1-15      
[25] Matrix_1.2-17      reticulate_1.12    splines_3.6.1      gower_0.2.0        stringr_1.4.0      foreign_0.8-71    
[31] RCurl_1.95-4.12    munsell_0.5.0      compiler_3.6.1     xfun_0.6           pkgconfig_2.0.2    base64enc_0.1-3   
[37] mnormt_1.5-5       nnet_7.3-12        tidyselect_0.2.5   tibble_2.1.3       prodlim_2018.04.18 codetools_0.2-16  
[43] fansi_0.4.0        crayon_1.3.4       dplyr_0.8.3        withr_2.1.2        MASS_7.3-51.4      bitops_1.0-6      
[49] recipes_0.1.5      ModelMetrics_1.2.2 grid_3.6.1         nlme_3.1-140       jsonlite_1.6       gtable_0.3.0      
[55] magrittr_1.5       scales_1.0.0       cli_1.1.0          stringi_1.4.3      reshape2_1.4.3     doParallel_1.0.14 
[61] timeDate_3043.102  vctrs_0.2.0        generics_0.0.2     lava_1.6.5         iterators_1.0.10   tools_3.6.1       
[67] glue_1.3.1         purrr_0.3.2        hms_0.5.0          parallel_3.6.1     survival_2.44-1.1  yaml_2.2.0        
[73] colorspace_1.4-1   knitr_1.22    

Any ideas on why this bug is happening?

nagdevAmruthnath commented 5 years ago

One workaround I came up with is unlist(predictions$predictions)[-samples$Resample1]. It would be great if this could be handled within the predict() function itself.
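The workaround relies on two assumptions: that the prediction vector has one entry per row of `mtcars`, in the original row order, and that `samples$Resample1` holds the training row indices, so negative indexing keeps only the test rows. A minimal base-R sketch of that indexing follows, with a hypothetical vector `all_preds` standing in for the model output and a fixed index vector standing in for createDataPartition().

```r
# Hypothetical training indices standing in for samples$Resample1:
# every mtcars row except 3, 10, 20 and 30.
train_idx <- setdiff(seq_len(nrow(mtcars)), c(3, 10, 20, 30))

# Hypothetical stand-in for unlist(predictions$predictions):
# one value per mtcars row, in row order (here simply the row number).
all_preds <- seq_len(nrow(mtcars))

# Negative indexing drops the training positions, leaving the test rows.
test_preds <- all_preds[-train_idx]
print(test_preds)  # 3 10 20 30
```

Even when the indexing lines up, the predictions in the original script come from a model that saw all 32 rows during training, so the selected "test" predictions are not from held-out data.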