tobigithub / caret-machine-learning

Practical examples for the R caret machine learning package
MIT License
67 stars 50 forks source link

brnn prediction fails - Error in predict.brnn(modelFit, as.matrix(newdata)) : #17

Open tobigithub opened 8 years ago

tobigithub commented 8 years ago

brnn prediction fails in caret, while all other (70) regression methods work. Versions: caret_6.0-58 and brnn_0.5.

> p1 <- predict(t1 ,testing_data)
Error in predict.brnn(modelFit, as.matrix(newdata)) : 
  The number of predictors used to fit the model and those in `newdata' does not match

Full code


# Load caret (a hard dependency, so library() fails loudly if it is missing)
# and DT; also load the built-in `cars` data set, which this script does not use.
library(caret)
require(DT)
data(cars)

# Example only: the names of all caret methods that support regression
# (overwritten immediately below).
m <- unique(modelLookup()[modelLookup()$forReg, c(1)])

# Restrict m to the two fast methods compared in this example.
m <- c("brnn", "cubist")

# Load the package(s) behind each method. caret method names are not always
# package names (method "cubist" is implemented in package "Cubist"), so look
# the packages up with getModelInfo() instead of calling require() on `m`.
# require() only warns on failure, so this remains best effort.
m_packages <- unique(unlist(lapply(m, function(method) {
  getModelInfo(method, regex = FALSE)[[1]]$library
})))
suppressPackageStartupMessages(
  ll <- lapply(m_packages, require, character.only = TRUE)
)

# Evaluate `expr` without letting errors or warnings escape.
#
# Warnings are muffled (only the most recent one is kept) and an error stops
# evaluation of `expr` but is captured rather than re-thrown.
#
# Returns a list with elements:
#   value   - the result of `expr`, or NULL if it raised an error
#   warning - the last warning condition signalled while evaluating, or NULL
#   error   - the error condition raised, or NULL
myTryCatch <- function(expr) {
  last_warning <- NULL
  caught_error <- NULL

  on_warning <- function(w) {
    last_warning <<- w
    invokeRestart("muffleWarning")
  }
  on_error <- function(e) {
    caught_error <<- e
    NULL
  }

  result <- withCallingHandlers(
    tryCatch(expr, error = on_error),
    warning = on_warning
  )
  list(value = result, warning = last_warning, error = caught_error)
}

# Simulated regression data from caret's SLC14_1(): a small training set and
# a large test set. The response is column `y`; every other column is treated
# as a predictor (selected by name, not by position).
training_data <- SLC14_1(500)
testing_data <- SLC14_1(10^4)
y <- training_data$y
x <- training_data[, names(training_data) != "y", drop = FALSE]

# Register a parallel backend (one worker per detected core) for caret's
# resampling.
library(doParallel)
cl <- makeCluster(detectCores())
registerDoParallel(cl)

# Fit both models with cross-validation; re-seeding before each call gives
# both fits the same starting random state.
set.seed(123)
t1 <- train(y = y, x = x, "brnn", trControl = trainControl(method = "cv"))
set.seed(123)
t2 <- train(y = y, x = x, "cubist", trControl = trainControl(method = "cv"))

# Training is the only step that uses the workers. Release them now so the
# expected failure below cannot leave the cluster running.
stopCluster(cl)
registerDoSEQ()

getTrainPerf(t1)
getTrainPerf(t2)

# Predict the test set with cubist first (works).
p2 <- predict(t2, testing_data)
# Predict with brnn (fails, reproducing this issue): `testing_data` still
# contains the response column `y`, so predict.brnn() is given a different
# number of columns than the model was fitted with. Passing only the
# predictor columns, e.g. testing_data[, names(x)], avoids the error.
p1 <- predict(t1, testing_data)

# Test-set performance of cubist, computed by postResample().
r2 <- postResample(p2, testing_data$y)
r2

# Test-set performance of brnn (only reached if the prediction above succeeds).
r1 <- postResample(p1, testing_data$y)
r1

### END
tobigithub commented 8 years ago

Solution: pass only the predictor (x) columns to `predict()`, not the whole data frame that still contains the `y` column.

# cubist predicted
      RMSE   Rsquared 
11.6894075  0.6768111

# brnn predicted
      RMSE   Rsquared 
13.0454419  0.6026583 

See below.


# Load caret (a hard dependency, so library() fails loudly if it is missing)
# and DT; also load the built-in `cars` data set, which this script does not use.
library(caret)
require(DT)
data(cars)

# Example only: the names of all caret methods that support regression
# (overwritten immediately below).
m <- unique(modelLookup()[modelLookup()$forReg, c(1)])

# Restrict m to the two fast methods compared in this example.
m <- c("brnn", "cubist")

# Load the package(s) behind each method. caret method names are not always
# package names (method "cubist" is implemented in package "Cubist"), so look
# the packages up with getModelInfo() instead of calling require() on `m`.
# require() only warns on failure, so this remains best effort.
m_packages <- unique(unlist(lapply(m, function(method) {
  getModelInfo(method, regex = FALSE)[[1]]$library
})))
suppressPackageStartupMessages(
  ll <- lapply(m_packages, require, character.only = TRUE)
)

# Evaluate `expr` without letting errors or warnings escape.
#
# Warnings are muffled (only the most recent one is kept) and an error stops
# evaluation of `expr` but is captured rather than re-thrown.
#
# Returns a list with elements:
#   value   - the result of `expr`, or NULL if it raised an error
#   warning - the last warning condition signalled while evaluating, or NULL
#   error   - the error condition raised, or NULL
myTryCatch <- function(expr) {
  last_warning <- NULL
  caught_error <- NULL

  on_warning <- function(w) {
    last_warning <<- w
    invokeRestart("muffleWarning")
  }
  on_error <- function(e) {
    caught_error <<- e
    NULL
  }

  result <- withCallingHandlers(
    tryCatch(expr, error = on_error),
    warning = on_warning
  )
  list(value = result, warning = last_warning, error = caught_error)
}

# Simulated regression data from caret's SLC14_1(): a small training set and
# a large test set. The response is column `y`; every other column is treated
# as a predictor (selected by name, not by position).
training_data <- SLC14_1(500)
testing_data <- SLC14_1(10^4)
y <- training_data$y
x <- training_data[, names(training_data) != "y", drop = FALSE]

# Split the test set the same way. predict() must receive only the predictor
# columns: brnn rejects `newdata` whose column count differs from the
# predictors it was fitted with.
test_y <- testing_data$y
test_x <- testing_data[, names(testing_data) != "y", drop = FALSE]

# Optional parallel backend for caret's resampling (off by default). The
# cluster is only created -- and only stopped -- when this flag is TRUE, so
# `cl` always exists when stopCluster() is called.
use_parallel <- FALSE
if (use_parallel) {
  library(doParallel)
  cl <- makeCluster(detectCores())
  registerDoParallel(cl)
}

# Fit both models with boot632 resampling; re-seeding before each call gives
# both fits the same starting random state.
set.seed(123)
t1 <- train(
  y = y, x = x, "brnn",
  trControl = trainControl(method = "boot632")
)
set.seed(123)
t2 <- train(
  y = y, x = x, "cubist",
  trControl = trainControl(method = "boot632")
)

# Training is the only step that uses the workers; release them afterwards
# and return foreach to sequential execution.
if (use_parallel) {
  stopCluster(cl)
  registerDoSEQ()
}

getTrainPerf(t1)
getTrainPerf(t2)

# Predict the test set from the predictor columns only (both models work).
p2 <- predict(t2, test_x)
p1 <- predict(t1, test_x)

# Test-set performance of cubist, computed by postResample().
r2 <- postResample(p2, test_y)
r2

# Test-set performance of brnn.
r1 <- postResample(p1, test_y)
r1

### END