tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org
Other
564 stars 78 forks source link

Let `fit_xy()` take `dgCMatrix` input #1121

Open EmilHvitfeldt opened 1 month ago

EmilHvitfeldt commented 1 month ago

Ref: #1125

General idea:

TODO:

# Reprex setup: build a sparse term-frequency matrix from the `friends`
# dialogue data to demonstrate passing a dgCMatrix to fit_xy().
library(tidymodels)
library(textrecipes)
library(friends)

# Tokenize the `text` column and convert tokens to term-frequency features;
# prep() estimates the recipe on the full data set.
# NOTE(review): the warning below shows a sparse->dense coercion firing
# inside prep() -- the ~8.7 GiB allocation this issue is about avoiding.
preped_rec <- recipe(season ~ text, data = friends) %>%
  step_tokenize(text) %>%
  step_tf(text) %>%
  prep()
#> Warning in asMethod(object): sparse->dense coercion: allocating vector of size
#> 8.7 GiB

# Request the baked result as a sparse dgCMatrix rather than a tibble.
term_freq <- bake(preped_rec, new_data = NULL, composition = "dgCMatrix")

# 67,373 rows x 17,378 term columns.
dim(term_freq)
#> [1] 67373 17378

# The sparse representation is tiny (9.86 MB) compared with the ~8.7 GiB
# dense allocation warned about above.
lobstr::obj_size(term_freq)
#> 9.86 MB

# Linear regression spec backed by the glmnet engine.
lm_spec <- linear_reg(penalty = 0) |>
  set_engine("glmnet")

# Time fit_xy() on the sparse matrix directly: x drops the first column and
# y is that first column. NOTE(review): this assumes the outcome landed in
# column 1 of the baked matrix -- confirm against the recipe's column order.
tictoc::tic()
lm_fit <- fit_xy(lm_spec, x = term_freq[, -1], y = term_freq[, 1])
tictoc::toc()
#> 2.006 sec elapsed

# Print the fitted parsnip model. The recorded Call shows
# glmnet::glmnet(x = maybe_matrix(x), ...), i.e. the matrix input reached
# the engine; the table below is glmnet's fitted regularization path.
lm_fit
#> parsnip model object
#> 
#> 
#> Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian") 
#> 
#>       Df  %Dev   Lambda
#> 1      0  0.00 0.200600
#> 2      1  0.09 0.182800
#> 3      2  0.22 0.166500
#> 4      3  0.35 0.151700
#> 5      3  0.50 0.138300
#> 6      4  0.63 0.126000
#> 7      7  0.83 0.114800
#> 8      8  1.04 0.104600
#> 9     11  1.26 0.095300
#> 10    13  1.49 0.086830
#> 11    19  1.73 0.079120
#> 12    31  2.04 0.072090
#> 13    39  2.40 0.065680
#> 14    52  2.80 0.059850
#> 15    70  3.22 0.054530
#> 16    81  3.66 0.049690
#> 17   101  4.10 0.045270
#> 18   139  4.58 0.041250
#> 19   193  5.14 0.037590
#> 20   273  5.79 0.034250
#> 21   375  6.52 0.031210
#> 22   515  7.34 0.028430
#> 23   677  8.25 0.025910
#> 24   962  9.26 0.023610
#> 25  1208 10.40 0.021510
#> 26  1516 11.58 0.019600
#> 27  2001 12.83 0.017860
#> 28  2946 14.32 0.016270
#> 29  3538 15.93 0.014820
#> 30  4287 17.51 0.013510
#> 31  5048 19.10 0.012310
#> 32  5607 20.61 0.011210
#> 33  6149 22.00 0.010220
#> 34  6755 23.30 0.009311
#> 35  7295 24.50 0.008483
#> 36  7820 25.59 0.007730
#> 37  8359 26.58 0.007043
#> 38  8846 27.48 0.006417
#> 39  9370 28.30 0.005847
#> 40  9814 29.03 0.005328
#> 41 10265 29.71 0.004855
#> 42 10717 30.31 0.004423
#> 43 11068 30.84 0.004030
#> 44 11432 31.34 0.003672
#> 45 11753 31.77 0.003346
#> 46 12103 32.17 0.003049
#> 47 12389 32.51 0.002778
#> 48 12669 32.82 0.002531
#> 49 12956 33.09 0.002306
#> 50 13223 33.34 0.002101
#> 51 13505 33.55 0.001915
#> 52 13731 33.72 0.001745
#> 53 14016 33.92 0.001590
#> 54 14224 34.07 0.001448
#> 55 14406 34.18 0.001320
#> 56 14667 34.31 0.001203
#> 57 14791 34.41 0.001096
#> 58 14998 34.51 0.000998
#> 59 15103 34.59 0.000910
#> 60 15216 34.65 0.000829
#> 61 15329 34.70 0.000755
#> 62 15449 34.75 0.000688
#> 63 15542 34.78 0.000627
#> 64 15668 34.82 0.000571
#> 65 15739 34.87 0.000521
#> 66 15770 34.89 0.000474
#> 67 15843 34.91 0.000432
#> 68 15894 34.93 0.000394
#> 69 15957 34.94 0.000359
#> 70 16004 34.96 0.000327
#> 71 16044 34.96 0.000298
#> 72 16099 34.97 0.000271
#> 73 16141 34.98 0.000247
#> 74 16174 34.99 0.000225
#> 75 16211 35.00 0.000205
#> 76 16251 35.00 0.000187
#> 77 16339 35.01 0.000170
#> 78 16393 35.02 0.000155
#> 79 16396 35.02 0.000141
#> 80 16393 35.02 0.000129
#> 81 16401 35.03 0.000118
#> 82 16424 35.03 0.000107
#> 83 16450 35.03 0.000098
#> 84 16493 35.03 0.000089
#> 85 16509 35.04 0.000081
#> 86 16501 35.04 0.000074
#> 87 16523 35.04 0.000067
#> 88 16522 35.04 0.000061
#> 89 16528 35.04 0.000056
#> 90 16529 35.04 0.000051
#> 91 16542 35.04 0.000046
#> 92 16555 35.04 0.000042
#> 93 16567 35.04 0.000038
#> 94 16579 35.04 0.000035
#> 95 16649 35.05 0.000032
#> 96 16644 35.05 0.000029