AdrianAntico / AutoQuant

R package for automation of machine learning, forecasting, model evaluation, and model interpretation
GNU Affero General Public License v3.0

AutoCatBoostCARMA problems with t + 2 predictions #20

Closed MislavSag closed 5 years ago

MislavSag commented 5 years ago

Hi again,

I tried your AutoCatBoostCARMA function. It seems there is something wrong with the t+2, t+3, ... predictions. Here is a sample of my data:

structure(list(index = structure(c(13880, 13881, 13882, 13885, 
13886, 13887, 13888, 13889, 13892, 13893, 13894, 13895, 13896, 
13899, 13900, 13901, 13902, 13903, 13906, 13907), class = "Date"), 
    zadnja = c(351.75, 347, 348, 342, 339, 339.86, 342.61, 345, 
    340, 336.11, 331, 333.94, 330.01, 317, 313, 313.98, 315, 
    319.45, 313, 316)), row.names = c(NA, -20L), ticker = "HT", index_quo = ~index, index_time_zone = "UTC", class = c("tbl_time", 
"tbl_df", "tbl", "data.frame"))

And here is your function:

AutoCatBoostCARMA_forecast <-  RemixAutoML::AutoCatBoostCARMA(
  data = sample,
  TargetColumnName = "zadnja",
  DateColumnName = "index",
  FC_Periods = 5,
  TimeUnit = "day",
  TargetTransformation = TRUE,
  Lags = c(1:5)
)
AutoCatBoostCARMA_forecast$Forecast

Results are:

           index Predictions
   1: 2008-01-02          NA
   2: 2008-01-03          NA
   3: 2008-01-04          NA
   4: 2008-01-07          NA
   5: 2008-01-08          NA
  ---                       
2836: 2019-07-05    159.5785
2837: 2019-07-06          NA
2838: 2019-07-06     -1.0000
2839: 2019-07-07          NA
2840: 2019-07-07     -1.0000

For t+2 and onward, the results are NA and -1.

The same thing happens on a bigger sample.

P.S. I would like to add LSTM time series prediction code to your arsenal. Do you agree with that, and do you have some procedure to incorporate new models into your code?

AdrianAntico commented 5 years ago

@MislavSag Just made the fix. Looks like there was an issue with the updating step for the non-grouping case, along with something in the return portion at the end of the function. Your code should run fine now. Also, I'm okay with others adding functions to the package. I'll want to look through what you're submitting to ensure it fits nicely in the "AutoML" concept. If everything is good, I'll create the necessary documentation and add you as a contributor. If not, I'll usually request enhancements and can work with you on that if you want.
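For context, the "updating step" refers to the recursive pattern CARMA-style forecasters use: each horizon's prediction is fed back into the lag features for the next horizon, so a broken update leaves every horizon beyond t+1 with NA lags. Here is a minimal, hypothetical sketch of that pattern using a toy `lm` in place of CatBoost (this is an editor's illustration, not the package's actual code):

```r
# Hypothetical sketch of recursive multi-step forecasting with lag features
# (not RemixAutoML internals). Each prediction is appended to the series so
# it can serve as a lag for the next horizon; if this step fails, t+2 onward
# sees NA lag values.
series <- c(351.75, 347, 348, 342, 339, 339.86, 342.61, 345, 340, 336.11)
lags <- 1:3

# Build a lag matrix: column l holds the value from l steps back
make_lags <- function(y, lags) {
  sapply(lags, function(l) c(rep(NA, l), head(y, -l)))
}

X <- make_lags(series, lags)
fit <- lm(series ~ X)  # toy stand-in for the CatBoost regressor

for (h in 1:5) {
  new_lags <- rev(tail(series, max(lags)))[lags]  # lag 1 = most recent value
  pred <- sum(coef(fit) * c(1, new_lags))         # intercept + lag terms
  series <- c(series, pred)                       # feed prediction back in
}
print(tail(series, 5))  # five finite forecasts, no NAs
```

If the feedback line is skipped or mis-indexed, the lag columns for later horizons stay NA, which matches the NA/-1 pattern reported above.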

MislavSag commented 5 years ago

Now I get an error:

Error in val[!neg_idx] <- (x[!neg_idx] * lambda + 1)^(1/lambda) - 1 : 
  NAs are not allowed in subscripted assignments
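An editor's minimal reproduction of this class of error (not the package's exact code): when the vector being back-transformed contains NA, the logical index `x < 0` also contains NA, and R forbids NA logical subscripts in an assignment.

```r
# If predictions contain NA, the comparison x < 0 propagates the NA,
# and assigning through the resulting logical index is an error:
# "NAs are not allowed in subscripted assignments".
x <- c(1.5, NA, 2.0)   # NA from a failed prediction step
lambda <- 0.5
neg_idx <- x < 0        # FALSE NA FALSE
val <- numeric(length(x))
res <- try(val[!neg_idx] <- (x[!neg_idx] * lambda + 1)^(1 / lambda) - 1,
           silent = TRUE)
print(inherits(res, "try-error"))  # TRUE
```

So the error is a downstream symptom of NA predictions reaching the inverse-transformation step, not a problem with the transformation math itself.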

... Ok, I will write a function and send it to you so you can try it out.

AdrianAntico commented 5 years ago

I think the problem is with the transformation. A few things with that: first, on non-grouped data, the transformation is usually a bad choice, as the predictions can be really bad, whereas for grouped data the transformation almost always works better. Second, when I ran through a simple non-grouping example, the transformed values were all the same afterwards. So I'd try again with TargetTransformation set to FALSE and see if that works.

I'll be around today to take a look at your function. Can't wait to see it!

AdrianAntico commented 5 years ago

I might end up automatically turning off the transformation functionality unless grouping variables are included. Need to test out some more, however, as I'm not sure if it's simply not having a grouping variable or if there just isn't enough total data.

MislavSag commented 5 years ago

After some time, I tried the function again. I have only two columns: a date and a numeric vector. I don't use grouping. It looks like this:

Results <- AutoCatBoostCARMA(x,
                             TargetColumnName = "value",
                             DateColumnName = "datum",
                             GroupVariables = NULL,
                             FC_Periods = forecast_periods,
                             TimeUnit = "day",
                             TargetTransformation = FALSE,
                             Lags = c(1:10),
                             MA_Periods = c(1:10),
                             CalendarVariables = TRUE,
                             HolidayVariable = TRUE,
                             TimeTrendVariable = TRUE,
                             DataTruncate = FALSE,
                             SplitRatios = c(0.8, 0.1, 0.1),
                             TaskType = "GPU",
                             EvalMetric = "MAE",
                             GridTune = FALSE,
                             GridEvalMetric = "mae",
                             ModelCount = 1,
                             NTrees = 500,
                             PartitionType = "timeseries",
                             Timer = TRUE)

Forecasted values are far away from the last observation. For example, if the last observation is 500, forecasts are 300 or even 7.

AdrianAntico commented 5 years ago

Same data as before?

AdrianAntico commented 5 years ago

Can you try using RMSE instead of MAE for EvalMetric? I noticed an issue with CatBoost when using MAE...

MislavSag commented 5 years ago

Yes, same data set. The only difference is that this time I set TargetTransformation = FALSE. It behaves the same on a larger data set.

MislavSag commented 5 years ago

What's the difference between your AutoH2oGBMCARMA() and the H2O automl function?

AdrianAntico commented 5 years ago

Have you read through the README on the main GitHub page? There are sections for Supervised Learning Functions and Time Series Modeling Functions, and at the top of those sections I have notes about what each core set of functions does internally. The supervised learning functions are for regression, binary, and multinomial classification. The Time Series functions are for forecasting. AutoH2oGBMCARMA() runs just like the AutoCatBoostCARMA() function except it utilizes H2O's GBM algorithm instead of CatBoost. They both utilize the existing automl functions for regression internally, along with several others. In essence, the CARMA functions and the Hurdle functions are composites of several of the other supervised learning, scoring, and feature engineering functions.

DougVegas commented 5 years ago

What's the difference between your AutoH2oGBMCARMA() and the H2O automl function?

Per Adrian's reply, the AutoH2oGBMCARMA() function in RemixAutoML leverages the AutoH2oGBMRegression() function from RemixAutoML under the hood.

The difference between AutoH2oGBMRegression from RemixAutoML and h2o.automl from H2O is that the former uses the H2O GBM model and fully optimizes it, while the latter can use several other models in addition to GBM (Random Forest, Deep Learning, Stacked Ensemble) but may not optimize each one fully.

In addition, AutoH2oGBMRegression from RemixAutoML does several steps that h2o.automl in H2O does NOT do (such as automatic data partitioning, data type conversions, target variable transformations, etc.). For a full list of the steps that AutoH2oGBMRegression performs, see the README.md file on the main page under "Regression" in "Supervised Learning Functions".

MislavSag commented 5 years ago

OK, thanks for the explanation.

What about my first question? Why do I get such strange predictions for a univariate time series when using AutoH2oGBMCARMA()?

AdrianAntico commented 5 years ago

I'm pretty sure that's happening because you have the EvalMetric set to MAE, which is currently not working on the CatBoost side (I assume you mean AutoCatBoostCARMA, based on your previous post). Have you tried using RMSE instead? I had to switch a few models to RMSE because of that. I'm keeping an eye on the CatBoost issue log to see when they fix that.

MislavSag commented 5 years ago

I tried to use RMSE and got the same strange results. But when I tried manually:

TestModel <- AutoCatBoostRegression(
  data = x[1:as.integer(0.8 * nrow(x)), ],
  ValidationData = x[(as.integer(0.8 * nrow(x)) + 1):as.integer(0.9 * nrow(x)), ],
  TestData = x[(as.integer(0.9 * nrow(x)) + 1):nrow(x), ],
  TargetColumnName = "zadnja",
  FeatureColNames = "zadnja",
  PrimaryDateColumn = "datum",
  IDcols = NULL,
  TransformNumericColumns = NULL,
  MaxModelsInGrid = 10,
  task_type = "GPU",
  eval_metric = "RMSE",
  grid_eval_metric = "poisson",
  Trees = 500,
  GridTune = FALSE,
  model_path = getwd(),
  metadata_path = getwd(),
  ModelID = "ModelTest",
  NumOfParDepPlots = 1,
  ReturnModelObjects = TRUE,
  SaveModelObjects = TRUE,
  PassInGrid = NULL
)

Preds <- RemixAutoML::AutoCatBoostScoring(
  TargetType = "regression",
  ScoringData = x,
  FeatureColumnNames = NULL,
  IDcols = NULL,
  ModelObject = TestModel$Model,
  ModelPath = getwd(),
  ModelID = "ModelTest",
  ReturnFeatures = TRUE,
  TransformNumeric = FALSE,
  BackTransNumeric = FALSE,
  TransformationObject = NULL,
  TransID = NULL,
  TransPath = NULL,
  TargetColumnName = "zadnja",
  MDP_Impute = TRUE,
  MDP_CharToFactor = TRUE,
  MDP_RemoveDates = TRUE,
  MDP_MissFactor = "0",
  MDP_MissNum = -1
)

The first prediction is "understandable" (I didn't try the following predictions (t+2, ...) with updated tables).

AdrianAntico commented 5 years ago

Can you paste in or link the source data? I'll give it a shot when I get home from work and see if I can get to the bottom of this.

AdrianAntico commented 5 years ago

@MislavSag

After thinking more about this, I think the problem may lie with the data and what the AutoCARMA functions can currently support. Currently, the data must be complete, as in, no missing dates in the time series interval. With stock data, there are typically only 5 days of data per week; it skips the weekends and then continues. I have plans to handle that type of data eventually.
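One workaround, until gap handling lands, is to pad the business-day series to a complete daily calendar before calling a CARMA function. This is an editor's sketch of that idea using data.table, with weekend prices carried forward (the carry-forward is a modelling choice, not something the package does for you):

```r
library(data.table)

# Toy business-day series: Saturday/Sunday (2008-01-05, 2008-01-06) are missing
dt <- data.table(
  datum  = as.Date(c("2008-01-02", "2008-01-03", "2008-01-04",
                     "2008-01-07", "2008-01-08")),
  zadnja = c(351.75, 347, 348, 342, 339)
)

# Build a full daily calendar and left-join the observed data onto it
full <- data.table(datum = seq(min(dt$datum), max(dt$datum), by = "day"))
dt_complete <- dt[full, on = "datum"]

# Carry the last observed price forward over the gaps
setnafill(dt_complete, type = "locf", cols = "zadnja")
print(dt_complete)  # 7 rows; Jan 5 and Jan 6 filled with 348
```

The completed table can then be passed to AutoCatBoostCARMA() in place of the raw stock data.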

MislavSag commented 5 years ago

Sample:

x <- structure(list(datum = structure(c(13880, 13881, 13882, 13885, 
                                   13886, 13887, 13888, 13889, 13892, 13893, 13894, 13895, 13896, 
                                   13899, 13900, 13901, 13902, 13903, 13906, 13907, 13908, 13909, 
                                   13910, 13913, 13914, 13915, 13916, 13917, 13920, 13921, 13922, 
                                   13923, 13924, 13927, 13928, 13929, 13930, 13931, 13934, 13935, 
                                   13936, 13937, 13938, 13941, 13942, 13943, 13944, 13945, 13948, 
                                   13949, 13950, 13951, 13952, 13955, 13956, 13957, 13958, 13963, 
                                   13964, 13965, 13966, 13969, 13970, 13971, 13972, 13973, 13976, 
                                   13977, 13978, 13979, 13980, 13983, 13984, 13985, 13986, 13987, 
                                   13990, 13991, 13992, 13993, 13994, 13997, 13998, 13999, 14004, 
                                   14005, 14006, 14007, 14008, 14011, 14012, 14013, 14014, 14015, 
                                   14018, 14019, 14020, 14022, 14025, 14026, 14027, 14028, 14029, 
                                   14032, 14033, 14034, 14035, 14036, 14039, 14040, 14041, 14042, 
                                   14043, 14046, 14047, 14048, 14049, 14050, 14053, 14054, 14056, 
                                   14057, 14060, 14061, 14062, 14063, 14064, 14067, 14068, 14069, 
                                   14070, 14071, 14074, 14075, 14076, 14077, 14078, 14081, 14082, 
                                   14083, 14084, 14085, 14088, 14089, 14090, 14091, 14092, 14097, 
                                   14098, 14099, 14102, 14103, 14104, 14105, 14109, 14110, 14111, 
                                   14112, 14113, 14116, 14117, 14118, 14119, 14120, 14123, 14124, 
                                   14125, 14126, 14127, 14130, 14131, 14132, 14133, 14134, 14137, 
                                   14138, 14139, 14140, 14141, 14144, 14145, 14146, 14147, 14148, 
                                   14151, 14152, 14153, 14154, 14155, 14158, 14159, 14161, 14162, 
                                   14165, 14166, 14167, 14168, 14169, 14172, 14173, 14174, 14175, 
                                   14176, 14179, 14180, 14181, 14182, 14183, 14186, 14187, 14188, 
                                   14189, 14190, 14193, 14194, 14195, 14196, 14197, 14200, 14201, 
                                   14202, 14203, 14204, 14207, 14208, 14209, 14210, 14211, 14214, 
                                   14215, 14216, 14217, 14218, 14221, 14222, 14223, 14224, 14225, 
                                   14228, 14229, 14230, 14231, 14232, 14235, 14236, 14237, 14242, 
                                   14243, 14244, 14249, 14251, 14252, 14253, 14256, 14257, 14258, 
                                   14259, 14260, 14263, 14264, 14265, 14266, 14267, 14270, 14271, 
                                   14272, 14273, 14274, 14277, 14278, 14279, 14280, 14281, 14284, 
                                   14285, 14286, 14287, 14288, 14291, 14292, 14293, 14294, 14295, 
                                   14298, 14299, 14300, 14301, 14302, 14305, 14306, 14307, 14308, 
                                   14309, 14312, 14313, 14314, 14315, 14316, 14319, 14320, 14321, 
                                   14322, 14323, 14326, 14327, 14328, 14329, 14330, 14333, 14334, 
                                   14335, 14336, 14337, 14340, 14341, 14342, 14343, 14348, 14349, 
                                   14350, 14351, 14354, 14355, 14356, 14357, 14358, 14361, 14362, 
                                   14363, 14364, 14368, 14369, 14370, 14371, 14372, 14375, 14376, 
                                   14377, 14378, 14379, 14382, 14383, 14384, 14385, 14386, 14389, 
                                   14390, 14391, 14392, 14393, 14396, 14397, 14398, 14399, 14400, 
                                   14403, 14404, 14405, 14407, 14410, 14411, 14412, 14413, 14414, 
                                   14418, 14419, 14421, 14424, 14425, 14426, 14427, 14428, 14431, 
                                   14432, 14433, 14434, 14435, 14438, 14439, 14440, 14441, 14442, 
                                   14445, 14446, 14447, 14448, 14449, 14452, 14453, 14454, 14455, 
                                   14456, 14459, 14460, 14462, 14463, 14466, 14467, 14468, 14469, 
                                   14470, 14473, 14474, 14475, 14476, 14477, 14480, 14481, 14482, 
                                   14483, 14484, 14487, 14488, 14489, 14490, 14491, 14494, 14495, 
                                   14496, 14497, 14498, 14501, 14502, 14503, 14504, 14505, 14508, 
                                   14509, 14510, 14511, 14512, 14515, 14516, 14517, 14518, 14519, 
                                   14522, 14523, 14524, 14529, 14530, 14531, 14532, 14533, 14536, 
                                   14537, 14538, 14539, 14540, 14543, 14544, 14545, 14546, 14547, 
                                   14550, 14551, 14552, 14553, 14554, 14557, 14558, 14559, 14560, 
                                   14561, 14564, 14565, 14566, 14567, 14568, 14571, 14572, 14573, 
                                   14574, 14575, 14578, 14579, 14580, 14581, 14582, 14585, 14586, 
                                   14587, 14588, 14589, 14592, 14593, 14594, 14595, 14596, 14599, 
                                   14600, 14601, 14602, 14606, 14607, 14608, 14609, 14613, 14614, 
                                   14616), class = "Date"), zadnja = c(351.75, 347, 348, 342, 339, 
                                                                       339.86, 342.61, 345, 340, 336.11, 331, 333.94, 330.01, 317, 313, 
                                                                       313.98, 315, 319.45, 313, 316, 316.5, 315, 320, 315, 311.23, 
                                                                       305.55, 298.02, 291.8, 294.98, 296.44, 296, 294, 290.65, 288, 
                                                                       291.99, 295, 310, 303.1, 306.11, 309.51, 312.51, 328.1, 328.1, 
                                                                       324.8, 329.23, 337.01, 333.6, 333, 327.23, 328.5, 328.54, 324.5, 
                                                                       322, 317.01, 318, 319.98, 329.8, 323, 317, 318.55, 319.98, 323.99, 
                                                                       316.09, 315.01, 317.5, 315.03, 312.55, 312, 315, 312.89, 308.5, 
                                                                       295.53, 308, 315, 285.12, 284.34, 285, 281.39, 282.92, 285.94, 
                                                                       284.96, 282.9, 273.5, 273.5, 273.21, 281.14, 286.99, 283, 280.39, 
                                                                       283, 280, 285, 285.02, 289, 288, 284.5, 280.83, 278.3, 274.1, 
                                                                       276, 277, 275.5, 284.49, 294.5, 288.57, 287, 285, 285, 282, 277.56, 
                                                                       278.5, 279.8, 280, 282.5, 284, 280.31, 279, 280, 276.5, 276.53, 
                                                                       273.01, 273.21, 273, 270.17, 270, 269.06, 270.1, 268.51, 269.25, 
                                                                       270.28, 277.9, 271.05, 273.5, 272.4, 272.21, 274.01, 279.99, 
                                                                       282.03, 280.4, 280.32, 285, 283.99, 281.75, 282, 283.3, 284, 
                                                                       283, 282.45, 283.89, 282.18, 283, 283.5, 283.67, 285, 285, 288.99, 
                                                                       288.01, 283.9, 284.04, 283.5, 283.04, 277.01, 280, 281.01, 282, 
                                                                       284, 283.5, 282.21, 281, 280, 279, 277.5, 272, 272.3, 268.99, 
                                                                       260, 258.5, 250, 266.91, 266.99, 259.98, 262, 263.42, 262.7, 
                                                                       259.75, 258.5, 258.98, 256.9, 253, 243.05, 240.01, 228, 218.02, 
                                                                       241.9, 259.89, 242, 248.69, 245, 249.5, 250.02, 242, 235, 227, 
                                                                       220, 221, 223, 226, 226.5, 229.5, 243.02, 237.5, 230, 228.49, 
                                                                       230, 220, 219, 216.1, 218.25, 210, 204.15, 203.1, 193.06, 196.5, 
                                                                       197.16, 202, 198, 200, 196, 194, 190.35, 191, 193, 192, 196.44, 
                                                                       197, 198.11, 197, 194, 195, 195.5, 197, 197.48, 201.49, 207, 
                                                                       211.6, 211.1, 206.17, 200, 199.22, 203.5, 206.53, 211, 218, 228.78, 
                                                                       219.85, 209.25, 214.79, 220.18, 219, 220.04, 225, 222.5, 220, 
                                                                       220.06, 218.31, 222.21, 220, 221, 216.56, 217.02, 220, 221.38, 
                                                                       221.5, 221.5, 222.01, 222.06, 217.71, 222.36, 219.49, 218, 220, 
                                                                       222.48, 220, 219.42, 218.61, 220.5, 221.02, 228.9, 221, 205, 
                                                                       205.72, 206.13, 205, 201.03, 205.5, 205, 205.5, 203.85, 203.83, 
                                                                       203.5, 202.16, 204.24, 202.1, 204.52, 204.1, 206.7, 208.4, 207, 
                                                                       210, 217.99, 218.33, 226, 223.86, 223, 221.49, 219.99, 224, 224, 
                                                                       224, 222.68, 207.5, 207.99, 207, 205.13, 209, 209, 206.32, 204.5, 
                                                                       206.02, 211, 212.99, 210.21, 213.42, 220.1, 220.97, 216.11, 219.99, 
                                                                       216.3, 216.88, 217, 216.57, 217.99, 220.8, 217.51, 219.41, 220.5, 
                                                                       221.84, 223, 223.28, 222, 222.86, 220.25, 218.2, 216, 218, 215.99, 
                                                                       214.97, 215.5, 216.87, 214, 213.01, 213.7, 217.2, 216.73, 218, 
                                                                       217, 218.98, 218.01, 218.5, 219, 218, 218.5, 218, 219.85, 219.98, 
                                                                       218, 217.1, 218, 222, 230.73, 230, 229.07, 229.92, 231, 229, 
                                                                       226.9, 229.49, 228.86, 225.99, 225.5, 225.01, 226.4, 227.8, 224, 
                                                                       225, 225.51, 225.01, 225, 225, 228, 227.62, 226.32, 227, 225.97, 
                                                                       228, 228.33, 229.5, 229.34, 228.9, 228.28, 229.47, 228.06, 227.76, 
                                                                       227.8, 227.48, 227.75, 227.5, 226.91, 228.99, 230, 232.75, 237, 
                                                                       251.3, 258.12, 258.7, 259.73, 258.33, 258.5, 258, 258.2, 257.51, 
                                                                       258.53, 259.5, 260.5, 261, 260.5, 260, 260.11, 260.52, 264.3, 
                                                                       265.49, 270.1, 270.65, 271.65, 272.04, 272.92, 271.61, 272.25, 
                                                                       272.7, 271.31, 272, 271.62, 269, 270, 271.35, 270.1, 270.72, 
                                                                       270.82, 269, 266, 264.02, 267.78, 270, 269.61, 270.01, 270.01, 
                                                                       269.8, 269.8, 269, 267.01, 268.3, 271, 271.61, 272, 270.3, 271, 
                                                                       270.68, 272.45, 273, 271.91, 272.03, 272, 272.31, 272.6, 271.37, 
                                                                       272.7, 271, 270.49, 270.01, 268.05, 272.53, 274.95, 274.7, 274.17, 
                                                                       274, 273.5, 274.5, 276, 278, 276.81)), row.names = c(NA, -500L
                                                                       ), key = structure(list(.rows = list(1:500)), row.names = c(NA, 
                                                                                                                                   -1L), class = c("tbl_df", "tbl", "data.frame")), index = structure("datum", ordered = TRUE), index2 = "datum", interval = structure(list(
                                                                                                                                     year = 0, quarter = 0, month = 0, week = 0, day = 1, hour = 0, 
                                                                                                                                     minute = 0, second = 0, millisecond = 0, microsecond = 0, 
                                                                                                                                     nanosecond = 0, unit = 0), class = "interval"), class = c("tbl_ts", 
                                                                                                                                                                                               "tbl_df", "tbl", "data.frame"))

CARMA function:

Results <- AutoCatBoostCARMA(x,
                             TargetColumnName = "zadnja",
                             DateColumnName = "datum",
                             GroupVariables = NULL,
                             FC_Periods = 1,
                             TimeUnit = "day",
                             TargetTransformation = TRUE,
                             Lags = c(1:10),
                             MA_Periods = c(1:5),
                             CalendarVariables = FALSE,
                             HolidayVariable = FALSE,
                             TimeTrendVariable = FALSE,
                             DataTruncate = FALSE,
                             SplitRatios = c(0.8, 0.1, 0.1),
                             TaskType = "GPU",
                             EvalMetric = "MAE",
                             GridTune = FALSE,
                             GridEvalMetric = "RMSE",
                             ModelCount = 30,
                             NTrees = 100,
                             PartitionType = "timeseries",
                             Timer = TRUE)

I got an error:

Error in data.table::set(PlotData, i = (data[, .N] + 1):PlotData[, .N],  : 
  i[1] is 501 which is out of range [1,nrow=500].

It works if I try more than one forecast period:

Results <- AutoCatBoostCARMA(as.data.table(x),
                             TargetColumnName = "zadnja",
                             DateColumnName = "datum",
                             GroupVariables = NULL,
                             FC_Periods = 3,
                             TimeUnit = "day",
                             TargetTransformation = TRUE,
                             Lags = c(1:10),
                             MA_Periods = c(1:5),
                             CalendarVariables = FALSE,
                             HolidayVariable = FALSE,
                             TimeTrendVariable = FALSE,
                             DataTruncate = FALSE,
                             SplitRatios = c(0.8, 0.1, 0.1),
                             TaskType = "GPU",
                             EvalMetric = "MAE",
                             GridTune = FALSE,
                             GridEvalMetric = "RMSE",
                             ModelCount = 30,
                             NTrees = 100,
                             PartitionType = "timeseries",
                             Timer = TRUE)

but the predictions don't make sense: Results$Forecast[, 1:3]

          datum     zadnja Predictions
  1: 2008-01-02 351.750000    1.388187
  2: 2008-01-03 347.000000    1.378253
  3: 2008-01-04 348.000000    1.371332
  4: 2008-01-07 342.000000    1.358842
  5: 2008-01-08 339.000000    1.377182
 ---
498: 2010-01-04 276.000000    1.453191
499: 2010-01-05 278.000000    1.452219
500: 2010-01-07 276.810000    1.452452
501: 2010-01-08   1.452551    1.452551
502: 2010-01-09   1.405331    1.405331

AdrianAntico commented 5 years ago

I fixed the FC_Periods = 1 issue. As for the odd predicted values, there are a few things to note:

MislavSag commented 5 years ago

I tried again with your recommendations and now it works fine.

AdrianAntico commented 5 years ago

@MislavSag Thanks for your patience!