PAIR-code / what-if-tool

Source code/webpage/demos for the What-If Tool
https://pair-code.github.io/what-if-tool
Apache License 2.0

TensorBoard does not show attribution for WIT #152

Closed Lodewic closed 3 years ago

Lodewic commented 3 years ago

When running tensorboard --whatif-use-unsafe-custom-prediction my_predict.py with a custom_predict_fn() that returns attribution values, those attribution values fail to show up in the What-If Tool within TensorBoard.

The COMPAS demo shows SHAP attribution values in the Colab notebook, https://colab.research.google.com/github/PAIR-code/what-if-tool/blob/master/WIT_COMPAS_with_SHAP.ipynb.

I have confirmed that this same example works with no problem in Jupyter Lab and Jupyter Notebook. But I need it to work in TensorBoard as well.

Using the same COMPAS demo, I split custom_predict_fn() into a separate file and can make predictions with the WIT embedded in TensorBoard via tensorboard --whatif-use-unsafe-custom-prediction my_predict.py. There are no start-up errors, and predictions and features show up in the WIT, but no attribution values, even though I am sure custom_predict_fn() returns them.

What can I do to view attributions in TensorBoard? I need this because on Azure Databricks we can't use the notebook WitWidget, but we can use TensorBoard.

jameswex commented 3 years ago

Thanks for the bug report. I just ran a custom predict fn in TB myself and verified that we're not correctly passing attributions back to the frontend with TB custom prediction functions.

I'll need to fix this and put out a new patch version of the pip package.

jameswex commented 3 years ago

@Lodewic I have a PR that fixes this issue. You can wait for it to be submitted and for a new pip release of the tensorboard-plugin-wit package, or you can pull in this PR and build the pip package locally with "bazel run tensorboard_plugin_wit/pip_package:build_pip_package" and install it from the locally-built package.
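For anyone following along, the local build flow is roughly the following sketch. The bazel command is the one quoted above and the PR number comes from this thread; the wheel's output directory is an assumption (the build script prints the actual location), so adjust the final path accordingly.

```shell
# Clone the repo and check out the PR branch
# (GitHub exposes PRs as refs; #153 per this thread).
git clone https://github.com/PAIR-code/what-if-tool.git
cd what-if-tool
git fetch origin pull/153/head:pr-153
git checkout pr-153

# Build the pip package (command from the comment above).
bazel run tensorboard_plugin_wit/pip_package:build_pip_package

# Install the locally-built wheel; the exact output directory is
# printed by the build script, so this path is a guess.
pip install /tmp/wit-widget/*.whl
```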

If you're willing to pull PR #153 and build/test it locally with your custom predict fn, it'll help me verify that it works outside of just my tests.

For reference, here is what my test predict fn file looks like:

import random

# The function name "custom_predict_fn" must be exact.
def custom_predict_fn(examples, serving_bundle):
  # "examples" is a list of TFRecord objects, each of which contains the
  # features of one datapoint.
  # "serving_bundle" is a dictionary that contains the setup information
  # provided to the tool, such as server address, model name, model
  # version, etc.

  number_of_examples = len(examples)
  results = []
  attrs = []
  for _ in range(number_of_examples):
    score = random.random()
    results.append([score, 1 - score])  # For binary classification.
    attrs.append({'Age': random.random(),
                  'Capital-Loss': .4})
  return {'predictions': results, 'attributions': attrs}
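Before wiring a predict fn into TensorBoard, its return shape can be sanity-checked directly in Python. This is a hypothetical harness, not part of WIT; the dummy objects stand in for real example protos, which this particular fn never inspects:

```python
import random

# Copy of the test predict fn above (hypothetically saved as my_predict.py).
def custom_predict_fn(examples, serving_bundle):
  results = []
  attrs = []
  for _ in range(len(examples)):
    score = random.random()
    results.append([score, 1 - score])  # For binary classification.
    attrs.append({'Age': random.random(),
                  'Capital-Loss': .4})
  return {'predictions': results, 'attributions': attrs}

if __name__ == '__main__':
  # One prediction and one attribution dict per example is the shape
  # the tool expects back from a custom predict fn.
  out = custom_predict_fn([object()] * 3, serving_bundle=None)
  assert len(out['predictions']) == len(out['attributions']) == 3
  assert all(isinstance(a, dict) for a in out['attributions'])
  print('ok')
```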

Lodewic commented 3 years ago

Hi @jameswex , thank you so much for the quick response and fix. I look forward to the next release!

Bazel is new to me and I'm on a Windows machine, so I tried to test your PR by building it in a bazel:0.27.0 Docker container, but no luck so far. The build seems to succeed, albeit with many warnings, but I find no output in the tmp/wit-widget/ folder to get the wheel from.

Likely a local issue on my end... I'll try again today!