mlverse / tft

R implementation of Temporal Fusion Transformers
https://mlverse.github.io/tft/

make predict.tft() compatible with yardstick::metric_set() usage #5

Open: cregouby opened this issue 3 years ago


...so that the following part of the previous README can be restored:

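library(yardstick)  # metric_set(), accuracy(), precision(), recall(), roc_auc()
library(magrittr)   # %>%

# Evaluate hard class predictions with several metrics at once
metrics <- metric_set(accuracy, precision, recall)
cbind(vic_elec_test, predict(fit, vic_elec_test)) %>%
  metrics(Demand, estimate = .pred_class)

# Evaluate class probabilities (needs type = "prob")
cbind(vic_elec_test, predict(fit, vic_elec_test, type = "prob")) %>%
  roc_auc(Demand, .pred_No)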
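A minimal sketch of the data shape that `metric_set()` and `roc_auc()` consume, assuming predict.tft() follows the tidymodels convention of returning a `.pred_class` factor column for `type = "class"` and one `.pred_<level>` probability column per class level for `type = "prob"`. The toy data frame `preds` and its values are hypothetical and come neither from tft nor from `vic_elec`.

```r
library(yardstick)

# Hypothetical predictions in tidymodels column conventions.
# Factor levels default to alphabetical order: "No" is the first level.
preds <- data.frame(
  truth       = factor(c("Yes", "No", "Yes", "No")),
  .pred_class = factor(c("Yes", "No", "No", "No")),
  .pred_No    = c(0.1, 0.8, 0.6, 0.9),
  .pred_Yes   = c(0.9, 0.2, 0.4, 0.1)
)

# Several class metrics computed in one call on the hard predictions
metrics <- metric_set(accuracy, precision, recall)
res <- metrics(preds, truth, estimate = .pred_class)
print(res)

# roc_auc() uses the first factor level ("No") as the event by default,
# so it takes the probability column of that level
auc <- roc_auc(preds, truth, .pred_No)
print(auc)
```

Because yardstick treats the first factor level as the event by default (`event_level = "first"`), the probability column passed to `roc_auc()` has to match that level, which is why the snippet above uses `.pred_No`.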