ecpolley / SuperLearner

Current version of the SuperLearner R package

Question: Adding interactions for `glmnet` when also using other learners. #149

Closed by rdiaz02 8 months ago

rdiaz02 commented 8 months ago

(Not a bug, but a question)

I am using a set of models that include, among others, ranger, xgboost, earth, and glmnet, and there are subject-matter reasons to include interactions. ranger/random forest account for interactions; xgboost implicitly accounts for interactions with trees of depth >= 2; earth does too if we use degree >= 2.

For glmnet I've simply copied the original SL.glmnet to, say, SL.glmnet3, and where it had

X <- model.matrix(~-1 + ., X)
newX <- model.matrix(~-1 + ., newX)

I write, say, for 3-way (and 2-way) interactions:

X <- model.matrix(~-1 + .^3, X)
newX <- model.matrix(~-1 + .^3, newX)

This works fine with CV.SuperLearner. However, it does not work properly when I train on one data set and then predict on a different one, because of the behavior of predict.SL.glmnet (easily seen in the code itself and explained in https://github.com/ecpolley/SuperLearner/pull/65). The result is that interaction columns such as X1:X2 do not contain the product of X1 and X2 but 0.
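The symptom can be reproduced in base R without fitting anything. The sketch below uses made-up toy data, and the zero-padding step is only an imitation of the behavior described above, not a copy of the code in predict.SL.glmnet:

```r
# Toy data; the column names X1..X3 are illustrative only.
train <- data.frame(X1 = c(1, 2, 3, 4), X2 = c(2, 1, 0, 1), X3 = c(0, 1, 1, 0))
newd  <- data.frame(X1 = c(5, 6), X2 = c(3, 2), X3 = c(1, 1))

# Design matrix as a modified SL.glmnet3 would build it at training time:
# main effects plus all 2-way and 3-way interactions.
X_train <- model.matrix(~ -1 + .^3, train)

# Design matrix as a main-effects-only predict method would build it.
X_new <- model.matrix(~ -1 + ., newd)

# ASSUMPTION: imitate the padding of columns that are absent from newdata
# by adding them as all-zero columns, then reorder to match training.
missing_cols <- setdiff(colnames(X_train), colnames(X_new))
padded <- cbind(
  X_new,
  matrix(0, nrow(X_new), length(missing_cols),
         dimnames = list(NULL, missing_cols))
)
padded <- padded[, colnames(X_train), drop = FALSE]

# The expansion that prediction actually needs.
X_new_ok <- model.matrix(~ -1 + .^3, newd)

padded[, "X1:X2"]    # zeros: the interaction was never computed
X_new_ok[, "X1:X2"]  # X1 * X2 for each row of newd
```

Under that assumption the model still predicts without error, which is why the problem is easy to miss: the interaction coefficients are simply multiplied by zero.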

Passing an X that already contains all the interactions to the training step is not the way to go (I do not want to do that to ranger, xgboost, or earth, for example). I can quickly think of two ways of working around this issue:

  1. Make my SL.glmnet2, SL.glmnet3, ... return objects of class SL.glmnet2, SL.glmnet3, .... Then create a matching set of predict.SL.glmnet2, predict.SL.glmnet3, ..., and in these functions, where it says
newdata <- model.matrix(~-1 + ., newdata)

write

newdata <- model.matrix(~-1 + .^2, newdata)

(or 3, or whatever, as appropriate).

  2. Try to be smart inside the predict.SL.glmnet and expand as appropriate.

Option 1 is an ugly kludge, but it is easy to do. Option 2 seems much more elegant, but I think there are multiple places where I can make mistakes, some of which I might not even see or anticipate.
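For concreteness, Option 1 might look roughly like the sketch below. It is a simplified stand-in, not the package's SL.glmnet: the names SL.glmnet3 and predict.SL.glmnet3 are the hypothetical ones from this issue, only a few arguments are carried over, and it assumes that newdata has the same columns (and the same factor levels) in the same order as the training data.

```r
library(glmnet)

# Hypothetical learner: glmnet on main effects plus 2-way and 3-way interactions.
SL.glmnet3 <- function(Y, X, newX, family, obsWeights, id,
                       alpha = 1, nfolds = 10, useMin = TRUE, ...) {
  # Expand both matrices with the same formula.
  X    <- model.matrix(~ -1 + .^3, as.data.frame(X))
  newX <- model.matrix(~ -1 + .^3, as.data.frame(newX))
  fitCV <- glmnet::cv.glmnet(x = X, y = Y, weights = obsWeights,
                             family = family$family, alpha = alpha,
                             nfolds = nfolds)
  s <- if (useMin) "lambda.min" else "lambda.1se"
  pred <- predict(fitCV, newx = newX, type = "response", s = s)
  fit <- list(object = fitCV, useMin = useMin)
  # A dedicated class so that predict() dispatches to the method below
  # rather than to predict.SL.glmnet.
  class(fit) <- "SL.glmnet3"
  list(pred = pred, fit = fit)
}

# Matching predict method: uses the same .^3 expansion as training.
predict.SL.glmnet3 <- function(object, newdata, ...) {
  newdata <- model.matrix(~ -1 + .^3, as.data.frame(newdata))
  s <- if (object$useMin) "lambda.min" else "lambda.1se"
  predict(object$object, newx = newdata, type = "response", s = s)
}

# Stand-alone check on simulated data (no SuperLearner call involved).
set.seed(1)
n <- 100
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
Y <- X$X1 * X$X2 + rnorm(n, sd = 0.1)
newX <- data.frame(X1 = rnorm(5), X2 = rnorm(5), X3 = rnorm(5))

res <- SL.glmnet3(Y, X, newX, family = gaussian(),
                  obsWeights = rep(1, n), id = seq_len(n))
p <- predict(res$fit, newdata = newX)
```

Because the fit carries its own class, the other learners in the library (ranger, xgboost, earth) keep receiving the unexpanded X.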

Has anyone dealt with this before? Any comments? Thanks in advance.

ecpolley commented 8 months ago

I just use option 1. You could try to make something that looks at the 'call' object returned by glmnet and parses out the formula, but that does not seem easy to do and would likely cause other problems.
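As a rough illustration of what an Option 2 style approach could involve, one alternative to parsing the call is to guess the interaction order from the names of the fitted coefficients. The helpers infer_degree and expand_like_training below are hypothetical, not part of SuperLearner or glmnet, and the approach assumes that no original variable name or factor level contains a ":" character, which is exactly the kind of fragility mentioned above.

```r
# Hypothetical helper: guess the highest interaction order from coefficient
# names, e.g. rownames(fit$glmnet.fit$beta) for a cv.glmnet fit.
infer_degree <- function(coef_names) {
  # model.matrix joins interacting terms with ":", so k colons means
  # a (k + 1)-way term. ASSUMPTION: no raw variable name contains ":".
  n_colons <- lengths(regmatches(coef_names,
                                 gregexpr(":", coef_names, fixed = TRUE)))
  max(n_colons) + 1L
}

# Hypothetical helper: rebuild newdata with the same expansion as training.
expand_like_training <- function(newdata, coef_names) {
  degree <- infer_degree(coef_names)
  f <- as.formula(paste0("~ -1 + .^", degree))
  mm <- model.matrix(f, as.data.frame(newdata))
  # Subsetting by name fails loudly if a training column cannot be rebuilt,
  # instead of silently filling it with zeros.
  mm[, coef_names, drop = FALSE]
}

# Usage with toy data; the column names are illustrative only.
train <- data.frame(X1 = c(1, 2, 3), X2 = c(2, 5, 1), X3 = c(0, 1, 0))
train_names <- colnames(model.matrix(~ -1 + .^2, train))
newd <- data.frame(X1 = c(2, 3), X2 = c(4, 5), X3 = c(1, 1))
mm <- expand_like_training(newd, train_names)
```

Failing with an error when a column is missing is a deliberate choice in this sketch, since the silent zero columns are what made the original problem hard to notice.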

rdiaz02 commented 8 months ago

Thanks a lot for the advice! Closing this.