Closed: cnellington closed this 2 years ago.
Added the `base_y_predictor` and `base_param_predictor` kwargs for contextualized linear models. Adding a `base_y_predictor` makes the contextualized module predict only the non-base effects, while adding a `base_param_predictor` folds the base effects into the contextualized parameter estimates. Put another way: a `base_param_predictor` may help with learning heterogeneous effects but still includes the homogeneous (base) effects in the contextualized module's predictions, whereas a `base_y_predictor` means the contextualized module learns and predicts only the heterogeneous (non-base) effects.
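A minimal sketch of the difference, assuming a scalar linear model `y = beta(c) * x + mu(c)` and plain Python callables standing in for the torch modules. The function names are hypothetical and do not reflect the package's actual implementation: `predict_with_base_param` adds base parameters to the contextualized ones, while `predict_with_base_y` adds a base prediction to the output, so the contextualized parameters capture only the non-base effects.

```python
from typing import Callable, Tuple

# Hypothetical stand-ins for the torch modules:
# a parameter function maps a context value c to (beta, mu) for y = beta * x + mu.
ParamFn = Callable[[float], Tuple[float, float]]
# A hypothetical base-y function maps (c, x) directly to a base prediction.
BaseYFn = Callable[[float, float], float]


def predict_with_base_param(c: float, x: float, ctx_params: ParamFn,
                            base_params: ParamFn) -> Tuple[float, float, float]:
    """Add base parameters to the contextualized ones.

    The returned (beta, mu) include the homogeneous (base) effects.
    """
    b_ctx, m_ctx = ctx_params(c)
    b_base, m_base = base_params(c)
    beta, mu = b_ctx + b_base, m_ctx + m_base
    return beta, mu, beta * x + mu


def predict_with_base_y(c: float, x: float, ctx_params: ParamFn,
                        base_y: BaseYFn) -> Tuple[float, float, float]:
    """Add a base prediction to the output.

    The returned (beta, mu) capture only the heterogeneous (non-base) effects.
    """
    beta, mu = ctx_params(c)
    return beta, mu, base_y(c, x) + beta * x + mu


# Toy example: a homogeneous slope of 2 and intercept of 1,
# plus a context-dependent slope of 0.5 * c.
def ctx_params(c: float) -> Tuple[float, float]:
    return 0.5 * c, 0.0


def base_params(c: float) -> Tuple[float, float]:
    return 2.0, 1.0


def base_y(c: float, x: float) -> float:
    return 2.0 * x + 1.0


print(predict_with_base_param(1.0, 3.0, ctx_params, base_params))  # (2.5, 1.0, 8.5)
print(predict_with_base_y(1.0, 3.0, ctx_params, base_y))  # (0.5, 0.0, 8.5)
```

Both variants give the same prediction here, but the reported parameters differ: `(2.5, 1.0)` includes the base slope and intercept, while `(0.5, 0.0)` is only the heterogeneous part.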
These kwargs should be mutually exclusive in practice, but the tests use both at once, since the dummy classes are only there to check tensor shapes and to make sure we don't break torch autograd.
Agreed. I think this kind of warning applies to many of our kwargs, so I'll open an issue to sanity-check arguments across all contextualized models.
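One way such a sanity check might look, sketched as a hypothetical standalone helper (`check_base_predictor_kwargs` is not part of the package). It warns when both base predictors are passed, since they are meant to be mutually exclusive, but does not raise, so tests that deliberately pass both still run.

```python
import warnings


def check_base_predictor_kwargs(base_y_predictor=None, base_param_predictor=None) -> bool:
    """Hypothetical argument check for the two base predictor kwargs.

    Returns True if the combination is fine, False (with a UserWarning) if both are set.
    """
    if base_y_predictor is not None and base_param_predictor is not None:
        warnings.warn(
            "base_y_predictor and base_param_predictor are both set; "
            "they are intended to be mutually exclusive.",
            UserWarning,
        )
        return False
    return True


# Only one base predictor (or none): no warning.
check_base_predictor_kwargs(base_y_predictor=object())
# Both base predictors: emits a UserWarning.
check_base_predictor_kwargs(base_y_predictor=object(), base_param_predictor=object())
```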
Added a base predictor kwarg to `ContextualizedRegressionBase` and enabled it for all `regression` and `easy` classes. Added base predictor kwarg quicktests to `regression` and `easy`.

Resolves #89 for `regression` and `easy`.