akionakamura / pr-tensorboard-keras-example

Example code using PR Curve TensorBoard plugin with Keras
GNU General Public License v3.0

Precision-Recall curve with Keras

A blog post describing the work here can be found on my Medium profile.

TensorBoard is a suite of visualizations for inspecting and understanding your TensorFlow models and runs. The TensorBoard team recently released a "consistent set of APIs that allows developers to add custom visualization plugins to TensorBoard", and several plugins are already available.

Keras "is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano". When using the TensorFlow backend, Keras provides a TensorBoard callback so you can take advantage of TensorBoard's visualizations.

Keras' TensorBoard callback, however, does not yet support these plugins. I recently wanted to use the Precision-Recall curve plugin (pr_curve) to see how my binary classifier was doing, so I ended up writing an extension of the callback that supports it. The support is only partial (it does not use sample weights, for example), but I have found very little material about this on the web, so hopefully this code will help anyone who needs something similar.
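To make concrete what a precision-recall curve contains, here is a minimal pure-Python sketch (no TensorFlow needed) that computes precision and recall for a binary classifier at evenly spaced thresholds between 0 and 1. It is not the callback from this repository and not the pr_curve plugin's own implementation. The function name `pr_curve_points`, the rule that a score counts as positive when it is `>=` the threshold, and reporting 0.0 when a ratio's denominator is zero are all assumptions of this illustration. Like the callback described above, it ignores sample weights.

```python
from typing import List, Sequence, Tuple


def pr_curve_points(
    labels: Sequence[int],
    predictions: Sequence[float],
    num_thresholds: int = 11,
) -> List[Tuple[float, float, float]]:
    """Return (threshold, precision, recall) triples (hypothetical helper).

    Thresholds are evenly spaced in [0, 1]; a prediction is treated as
    positive when it is >= the threshold (an assumption of this sketch).
    Sample weights are not supported.
    """
    if num_thresholds < 2:
        raise ValueError("num_thresholds must be at least 2")

    points = []
    for i in range(num_thresholds):
        threshold = i / (num_thresholds - 1)
        tp = fp = fn = 0
        for label, score in zip(labels, predictions):
            predicted_positive = score >= threshold
            if predicted_positive and label:
                tp += 1  # true positive
            elif predicted_positive and not label:
                fp += 1  # false positive
            elif not predicted_positive and label:
                fn += 1  # false negative
        # Report 0.0 instead of dividing by zero (assumed convention).
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        points.append((threshold, precision, recall))
    return points


if __name__ == "__main__":
    labels = [1, 0, 1, 1, 0]
    scores = [0.9, 0.8, 0.6, 0.3, 0.1]
    for threshold, precision, recall in pr_curve_points(labels, scores, num_thresholds=3):
        print(f"{threshold:.2f} {precision:.2f} {recall:.2f}")
    # 0.00 0.60 1.00
    # 0.50 0.67 0.67
    # 1.00 0.00 0.00
```

Raising the threshold trades recall for precision; plotting these pairs over many thresholds gives the curve that the plugin renders in TensorBoard.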

Run the Example:

Make sure you have pipenv installed so that you get the right dependencies to reproduce the example. Once it is installed, run:

pipenv shell
python3 example.py
tensorboard --logdir=./logs

The script will download a small dataset to run the example on real data. The data is from UC Irvine Machine Learning Repository, specifically the Breast Cancer Wisconsin (Diagnostic) Data Set.

Here is a screenshot of the result:

Example screen shot