NVIDIA / spark-rapids-tools

User tools for Spark RAPIDS
Apache License 2.0

Add shap command to internal CLI for debugging #1197

Closed: leewyang closed 2 months ago

leewyang commented 2 months ago

This PR adds a shap command to the internal CLI to help explain a specific (per-sql) XGBoost prediction.

Usage:

python qualx_main.py shap --help

Example:

python qualx_main.py shap --platform $PLATFORM \
--prediction_output /path/to/prediction/output \
--index 0
# --model $MODEL   # optional

# --index should be a numeric zero-based index pointing to a specific line (i.e. sqlID) in the `shap_values.csv` file.
# Each line in this file corresponds to the same line (sqlID) in the `per_sql.csv` file.
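
To illustrate the note above, the sketch below treats `--index` as a row position rather than a sqlID value, so the same position selects matching rows in both files. The CSV contents and column names (`appId`, `sqlID`, and the two feature columns) are hypothetical stand-ins, and skipping a header row is an assumption; the real file layouts are not shown in this PR.

```python
import csv
import io

# Hypothetical file contents; the real column layouts are not shown in this PR.
per_sql_csv = "appId,sqlID\napp-1,0\napp-1,3\napp-2,1\n"
shap_values_csv = (
    "executorCPUTime_mean,sw_bytesWrittenRatio\n"
    "-0.1192,0.0478\n"
    "0.0100,-0.0200\n"
    "0.0000,0.0300\n"
)


def row_at(text, index):
    """Return the zero-based data row (header excluded) as a dict."""
    rows = list(csv.DictReader(io.StringIO(text)))
    return rows[index]


# Index 1 is the second data row in both files; its sqlID happens to be 3,
# which shows that the index is a position, not the sqlID itself.
index = 1
sql_row = row_at(per_sql_csv, index)
shap_row = row_at(shap_values_csv, index)

print(sql_row["sqlID"], shap_row)
```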

The output of the command looks like:

+-----+-------------------------------------------------+--------------+--------------+--------------------+--------------+-------------+-------------+-------------+-------------+-------------+-------------+-----------------+----------------+
|     | feature                                         |   shap_value |   model_rank |   model_shap_value |   train_mean |   train_std |   train_min |   train_25% |   train_50% |   train_75% |   train_max |   feature_value | out_of_range   |
|-----+-------------------------------------------------+--------------+--------------+--------------------+--------------+-------------+-------------+-------------+-------------+-------------+-------------+-----------------+----------------|
|   0 | executorCPUTime_mean                            |      -0.1192 |            0 |             0.1927 |      1.8e+03 |     5.1e+03 |     7.0e+01 |     2.6e+02 |     6.0e+02 |     1.2e+03 |     6.1e+04 |         4.9e+02 | False          |
|   1 | sw_bytesWrittenRatio                            |       0.0478 |            7 |             0.0220 |      9.6e-01 |     1.3e+00 |     1.2e-06 |     6.4e-03 |     3.2e-01 |     1.8e+00 |     1.2e+01 |         2.4e+00 | False          |
|   2 | executorDeserializeCPUTime_mean                 |      -0.0475 |            5 |             0.0237 |      6.7e+00 |     3.1e+00 |     2.1e+00 |     5.6e+00 |     6.2e+00 |     7.3e+00 |     3.7e+01 |         3.9e+01 | True           |
|   3 | sw_recordsWritten_sum                           |      -0.0339 |            1 |             0.0711 |      1.5e+09 |     3.8e+09 |     2.6e+02 |     1.7e+06 |     7.0e+07 |     8.7e+08 |     2.4e+10 |         1.2e+08 | False          |
...
| 106 | sqlOp_CommandResult                             |       0.0000 |          106 |             0.0000 |      0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |         0.0e+00 | False          |
| 107 | sqlOp_WindowSort                                |       0.0000 |          107 |             0.0000 |      0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |     0.0e+00 |         0.0e+00 | False          |
+-----+-------------------------------------------------+--------------+--------------+--------------------+--------------+-------------+-------------+-------------+-------------+-------------+-------------+-----------------+----------------+
Shap base value: 0.4152
Shap values sum: -0.0120
Shap prediction: 0.4032
exp(prediction): 1.4965
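
The four summary lines are consistent with SHAP's additive property: the raw prediction equals the base value plus the sum of the per-feature shap values. The sketch below recomputes them from the numbers printed above. Reading the final `exp()` as a conversion out of log space is an assumption, and because the inputs are already rounded to four decimals, the recomputed exponential can differ from the printed one in the last digit.

```python
import math

# Values copied from the sample output above (already rounded to 4 decimals).
base_value = 0.4152
shap_sum = -0.0120

# SHAP additivity: raw model output = base value + sum of per-feature contributions.
prediction = base_value + shap_sum

# Assumption: the model output is in log space, so exp() recovers the final value.
exp_prediction = math.exp(prediction)

print(f"Shap prediction: {prediction:.4f}")
print(f"exp(prediction): {exp_prediction:.4f}")
```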

Where:

Changes

  1. Added features.csv to save the feature values used for prediction.
  2. Renamed the existing shap_values.csv to feature_importance.csv, which better describes its purpose.
  3. Used shap_values.csv to save all of the shap values per feature per instance/sqlID during prediction.
  4. Saved a model.metrics file (for each model) during training to store the feature shap values and distribution metrics of the training set.
  5. Renamed the model.json.cfg files to model.cfg to avoid the double-suffix.
  6. Refactored/combined the compute_feature_importance and compute_shapley_values functions.
  7. Updated the internal predict CLI to support the --qual_output argument.
  8. Added a shap command to the internal CLI, which joins the prediction shap_values with the training shap_values and distribution metrics.
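
The join described in item 8 can be sketched roughly as follows, using two rows from the sample output. This is not the PR's implementation: the function name `join_shap` and the dict-based layout are hypothetical. Two rules are assumptions inferred from the sample table: `out_of_range` is taken to mean that the feature value falls outside `[train_min, train_max]`, and rows are sorted by absolute prediction shap value.

```python
# Per-feature values for one prediction (copied from the sample output).
prediction_shap = {
    "executorDeserializeCPUTime_mean": -0.0475,
    "executorCPUTime_mean": -0.1192,
}
feature_values = {
    "executorDeserializeCPUTime_mean": 3.9e1,
    "executorCPUTime_mean": 4.9e2,
}
# Training-set metrics per feature (a subset of the columns shown in the table).
train_metrics = {
    "executorCPUTime_mean": {
        "model_shap_value": 0.1927, "train_min": 7.0e1, "train_max": 6.1e4,
    },
    "executorDeserializeCPUTime_mean": {
        "model_shap_value": 0.0237, "train_min": 2.1, "train_max": 3.7e1,
    },
}


def join_shap(prediction_shap, feature_values, train_metrics):
    """Hypothetical helper: merge prediction and training data per feature."""
    rows = []
    for feature, shap_value in prediction_shap.items():
        metrics = train_metrics[feature]
        value = feature_values[feature]
        rows.append({
            "feature": feature,
            "shap_value": shap_value,
            **metrics,
            "feature_value": value,
            # Assumption: flag values outside the range seen during training.
            "out_of_range": not (metrics["train_min"] <= value <= metrics["train_max"]),
        })
    # Assumption: largest absolute contribution first, as in the sample table.
    rows.sort(key=lambda r: abs(r["shap_value"]), reverse=True)
    return rows


rows = join_shap(prediction_shap, feature_values, train_metrics)
for row in rows:
    print(row["feature"], row["shap_value"], row["out_of_range"])
```

Under these assumptions, executorDeserializeCPUTime_mean is flagged because its value (3.9e+01) exceeds the training maximum (3.7e+01), which matches the `True` in the sample output.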

Test

The following commands have been tested:

External Usage:

spark-rapids train
spark-rapids predict

Internal Usage:

python qualx_main.py preprocess
python qualx_main.py train
python qualx_main.py predict
python qualx_main.py shap