rstudio / sparkxgb

R interface for XGBoost on Spark
https://spark.posit.co/packages/sparkxgb/

sparkxgb assigns wrong probabilities when used with two-class classification #23

Open mzorko opened 4 years ago

mzorko commented 4 years ago
library(sparkxgb)
library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local")
iris_tbl <- sdf_copy_to(sc, iris)
iris_tbl <- iris_tbl %>% filter(Species != "setosa")

xgb_model <- xgboost_classifier(
  iris_tbl, Species ~ .,
  num_class = 2,
  num_round = 50, 
  max_depth = 4)

ml_predict(xgb_model, iris_tbl) %>%
  select(Species, predicted_label, starts_with("probability_"))

# Source: spark<?> [?? x 4]
# Species    predicted_label probability_versicolor probability_virginica
# <chr>      <chr>                            <dbl>                 <dbl>
# 1 versicolor virginica                      0.00313                 0.997
# 2 versicolor virginica                      0.00313                 0.997
# 3 versicolor virginica                      0.0121                  0.988
# 4 versicolor virginica                      0.0120                  0.988
# 5 versicolor virginica                      0.00601                 0.994
# 6 versicolor virginica                      0.00180                 0.998
# 7 versicolor virginica                      0.00368                 0.996
# 8 versicolor virginica                      0.0433                  0.957
# 9 versicolor virginica                      0.00260                 0.997
#10 versicolor virginica                      0.00180                 0.998
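
The inversion visible in this output can be quantified once the predictions are pulled into a local data frame (for example with `collect()`). The following is a minimal base R sketch that hard-codes rows 1 to 5 of the output above instead of calling Spark. The variable names `accuracy` and `prob_true` are illustrative and not part of sparkxgb.

```r
# Rows 1-5 of the ml_predict() output shown above, as a local data frame
preds <- data.frame(
  Species = rep("versicolor", 5),
  predicted_label = rep("virginica", 5),
  probability_versicolor = c(0.00313, 0.00313, 0.0121, 0.0120, 0.00601),
  probability_virginica = c(0.997, 0.997, 0.988, 0.988, 0.994),
  stringsAsFactors = FALSE
)

# Share of rows where the predicted label matches the true label
accuracy <- mean(preds$predicted_label == preds$Species)

# Probability assigned to the true class in each row
prob_mat <- as.matrix(preds[c("probability_versicolor", "probability_virginica")])
true_col <- match(paste0("probability_", preds$Species), colnames(prob_mat))
prob_true <- prob_mat[cbind(seq_len(nrow(preds)), true_col)]

print(accuracy)
print(prob_true)
```

On training data one would normally expect high accuracy and a high probability for the true class. Values near zero for both, as in these rows, are consistent with the two classes being inverted rather than with a weak model.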

kevinykuo commented 4 years ago

Looks like a bug! Do the multiclass (>2) models behave properly?

mzorko commented 4 years ago

I haven't looked into multiclass models in detail yet. Playing with the small example below, I couldn't find any problems for n > 2. For n = 2, the probabilities are always swapped.

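n <- 2
# x fully determines target, so a correct model should map x = 1 to "A" and x = 2 to "B"
xgb_tbl <- data.frame(x = sample(1:n, 100, replace = TRUE)) %>%
  mutate(target = LETTERS[x]) %>%
  sdf_copy_to(sc, ., overwrite = TRUE)

# Count predictions per value of x; any (x, predicted_label) pair that disagrees with LETTERS[x] is an error
xgboost_classifier(
  xgb_tbl, target ~ ., num_class = n, num_round = 50, max_depth = 6) %>%
  ml_predict(xgb_tbl) %>%
  group_by(x, predicted_label) %>%
  count() %>%
  arrange(x) %>%
  collect() %>%
  data.frame()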
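
Until the bug is fixed, one possible stopgap is to swap the two probability columns on the collected predictions and re-derive the label from them. The sketch below is in base R and assumes the columns are exactly exchanged, as reported here for n = 2. The helper `swap_two_class` is hypothetical and not part of sparkxgb, and the data frame holds made-up values rather than results from a real run.

```r
# Hypothetical collected predictions for a two-class model
# (illustrative values, not taken from a real run)
preds <- data.frame(
  target = c("A", "A", "B", "B"),
  predicted_label = c("B", "B", "A", "A"),
  probability_A = c(0.02, 0.10, 0.95, 0.80),
  probability_B = c(0.98, 0.90, 0.05, 0.20),
  stringsAsFactors = FALSE
)

# Hypothetical helper: exchange the two probability columns,
# then recompute the label from the corrected probabilities
swap_two_class <- function(df, prefix = "probability_") {
  prob_cols <- grep(paste0("^", prefix), names(df), value = TRUE)
  stopifnot(length(prob_cols) == 2)
  df[prob_cols] <- df[rev(prob_cols)]
  classes <- sub(prefix, "", prob_cols, fixed = TRUE)
  best <- max.col(as.matrix(df[prob_cols]), ties.method = "first")
  df$predicted_label <- classes[best]
  df
}

fixed <- swap_two_class(preds)
acc_before <- mean(preds$predicted_label == preds$target)
acc_after <- mean(fixed$predicted_label == fixed$target)
print(c(before = acc_before, after = acc_after))
```

This only patches the output after the fact and should not be applied to models with more than two classes, where no problem was observed.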