tidymodels / planning

Documents to plan and discuss future development

Transfer Learning via tidymodels #22

Open uriahf opened 2 years ago

uriahf commented 2 years ago

First of all, thank you everyone for your hard work. I've been addicted to tidymodels and a heavy user ever since I discovered it.

I'm not sure if this is the right place, but I wonder if you have ever thought about implementing transfer learning via a tidymodels workflow?

EmilHvitfeldt commented 2 years ago

Hello @uriahf 👋 Thank you for your interest in tidymodels!

Could you give us some examples of methods that you wish to use but are unable to use within the current set of tidymodels packages?

uriahf commented 2 years ago

I will look for a good reproducible example; it might take a while.

Thank you!

juliasilge commented 2 years ago

I don't think we specifically need a reproducible example (a reprex), but if you can point us to the method or type of analysis you want to do, that would be super helpful!

uriahf commented 2 years ago

Well, I made a reprex anyway 😅

I have an important feature with missing values (missing not at random) and I want to use it directly. I don't want to use imputation because that might hurt the explainability of the model, and I definitely don't want to use categorization because of the loss of information.

I think that the Titanic data set might be a good example: Age is an important feature which is missing for some observations.

The workflow goes like this:

1. Training a "thin model" on the train set, for all observations, using only the features that have no missing values (Fare and Sex).

2. Adding predictions from the "thin model" as a separate predictor, thin_model_preds, in both the train set and the test set (it does not contain missing values).

3. Training a "full model" on the train set, only for observations with Age, using Fare and Sex (as in the thin model), Age, and the "thin model" predictions thin_model_preds.

4. Adding predictions from the "full model", full_model_preds, to both the train and test sets (they contain missing values for observations without Age).

5. Choosing the final predictions final_preds: the thin model predictions thin_model_preds for observations without Age, and the full model predictions full_model_preds for observations with Age.

I hope that's clear enough.

Here is the reprex:

library(titanic)
library(magrittr)

data(titanic_train)
data(titanic_test)

titanic_train <- titanic_train %>% dplyr::select(Age, Fare, Sex, Survived)
titanic_test <- titanic_test %>% dplyr::select(Age, Fare, Sex)

# 1. Training a thin model for all observations

thin_model <- glm(
  Survived ~ Fare + Sex,
  data = titanic_train,
  family = "binomial"
)

# 2. Adding thin model predictions to all observations in train and test sets

titanic_train_with_thin_model_predictions <- predict(thin_model,
  type = "response",
  newdata = titanic_train
) %>%
  tibble::tibble("thin_model_preds" = .) %>%
  dplyr::bind_cols(titanic_train)

titanic_test_with_thin_model_predictions <- predict(thin_model,
  type = "response",
  newdata = titanic_test
) %>%
  tibble::tibble("thin_model_preds" = .) %>%
  dplyr::bind_cols(titanic_test)

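# 3. Training full model only for observations with age
# while using age and thin model predictions as predictors
# (no explicit filter is needed: glm() drops rows with missing Age
# through the default na.action, na.omit)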

full_model <- glm(
  Survived ~ Fare + Sex + Age + thin_model_preds,
  data = titanic_train_with_thin_model_predictions,
  family = "binomial"
)
#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

# 4. Adding full model predictions to observations with age in train and test sets

titanic_train_with_predictions <- predict(full_model,
  type = "response",
  newdata = titanic_train_with_thin_model_predictions
) %>%
  tibble::tibble("full_model_preds" = .) %>%
  dplyr::bind_cols(titanic_train_with_thin_model_predictions)

titanic_test_with_predictions <- predict(full_model,
  type = "response",
  newdata = titanic_test_with_thin_model_predictions
) %>%
  tibble::tibble("full_model_preds" = .) %>%
  dplyr::bind_cols(titanic_test_with_thin_model_predictions) 

# 5. Full model predictions as final predictions for observations with age 
# and Thin model predictions as final predictions for observations without age

titanic_train <- titanic_train_with_predictions %>%
  dplyr::mutate(final_preds = ifelse(is.na(full_model_preds), thin_model_preds, 
                                     full_model_preds)) %>% 
  dplyr::select(-c(full_model_preds, thin_model_preds))

titanic_test <- titanic_test_with_predictions %>%
  dplyr::mutate(final_preds = ifelse(is.na(full_model_preds), thin_model_preds, 
                                     full_model_preds)) %>% 
  dplyr::select(-c(full_model_preds, thin_model_preds))

titanic_train
#> # A tibble: 891 x 5
#>      Age  Fare Sex    Survived final_preds
#>    <dbl> <dbl> <chr>     <int>       <dbl>
#>  1    22  7.25 male          0       0.152
#>  2    38 71.3  female        1       0.817
#>  3    26  7.92 female        1       0.687
#>  4    35 53.1  female        1       0.756
#>  5    35  8.05 male          0       0.136
#>  6    NA  8.46 male          0       0.157
#>  7    54 51.9  male          0       0.277
#>  8     2 21.1  male          0       0.261
#>  9    27 11.1  female        1       0.686
#> 10    14 30.1  female        1       0.742
#> # ... with 881 more rows

titanic_test
#> # A tibble: 418 x 4
#>      Age  Fare Sex    final_preds
#>    <dbl> <dbl> <chr>        <dbl>
#>  1  34.5  7.83 male         0.135
#>  2  47    7    female       0.629
#>  3  62    9.69 male         0.107
#>  4  27    8.66 male         0.150
#>  5  22   12.3  female       0.700
#>  6  14    9.22 male         0.174
#>  7  30    7.63 female       0.676
#>  8  26   29    male         0.247
#>  9  18    7.23 female       0.707
#> 10  21   24.2  male         0.234
#> # ... with 408 more rows
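
The five steps above are repeated once for the train set and once for the test set. As a rough sketch only, they could be wrapped in a fit/predict pair so the same logic is applied to any data set. The helper names fit_two_stage() and predict_two_stage() are hypothetical and not part of tidymodels or any other package, and the toy data below is made up so that the block runs without the titanic package.

```r
# Hypothetical helpers (not from any package): a base R sketch of the
# thin model / full model workflow described in steps 1 to 5.

fit_two_stage <- function(data, outcome, complete_vars, partial_var) {
  # Step 1: thin model on all rows, using only fully observed predictors
  thin_model <- glm(
    reformulate(complete_vars, response = outcome),
    data = data,
    family = binomial()
  )

  # Step 2: add the thin model predictions as an extra predictor
  data$thin_model_preds <- predict(thin_model, newdata = data, type = "response")

  # Step 3: full model; glm() drops rows where partial_var is missing
  # because the default na.action is na.omit
  full_model <- glm(
    reformulate(c(complete_vars, partial_var, "thin_model_preds"),
                response = outcome),
    data = data,
    family = binomial()
  )

  list(thin = thin_model, full = full_model, partial_var = partial_var)
}

predict_two_stage <- function(fit, new_data) {
  thin_preds <- predict(fit$thin, newdata = new_data, type = "response")
  new_data$thin_model_preds <- thin_preds

  # Step 4: predict() returns NA for rows where partial_var is missing
  full_preds <- predict(fit$full, newdata = new_data, type = "response")

  # Step 5: fall back to the thin model where partial_var is missing
  unname(ifelse(is.na(new_data[[fit$partial_var]]), thin_preds, full_preds))
}

# Made-up toy data with the same column names as the reprex
set.seed(1)
n <- 200
toy <- data.frame(
  Fare = runif(n, 5, 100),
  Sex = sample(c("male", "female"), n, replace = TRUE),
  Age = round(runif(n, 1, 70))
)
toy$Survived <- rbinom(
  n, 1,
  plogis(-1 + 0.02 * toy$Fare + 1.5 * (toy$Sex == "female") - 0.01 * toy$Age)
)
toy$Age[sample(n, 40)] <- NA  # make Age missing for some rows

fit <- fit_two_stage(toy, "Survived", c("Fare", "Sex"), "Age")
final_preds <- predict_two_stage(fit, toy)
```

With the Titanic data, the equivalent call would be fit_two_stage(titanic_train, "Survived", c("Fare", "Sex"), "Age"), followed by predict_two_stage(fit, titanic_test). Keeping fitting and prediction separate is also roughly the shape a tidymodels model specification would need.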

juliasilge commented 2 years ago

I'm going to move this to our planning repo so we can collect/track interest in an approach like this.

uriahf commented 2 years ago

Thank you so much!

caimiao0714 commented 1 year ago

I notice that there is already an R package for implementing transfer learning: glmtrans. I'm not sure if this would make incorporating transfer learning in tidymodels a bit easier.