MarcelRobeer / ContrastiveExplanation

Contrastive Explanation (Foil Trees), developed at TNO/Utrecht University
BSD 3-Clause "New" or "Revised" License

TensorFlow or PyTorch Support #4

Closed: jrinvictus closed this issue 4 years ago

jrinvictus commented 4 years ago

This is a great package. Do you plan to support TensorFlow or PyTorch in the future?

MarcelRobeer commented 4 years ago

It should work out of the box with TensorFlow and PyTorch if you wrap your current prediction function so that it converts the incoming NumPy arrays (np.ndarray) to TensorFlow/PyTorch tensors, makes the prediction, and converts the result (e.g. the output after softmax) back to a NumPy array.
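The wrapper idea above does not depend on the framework, so it can be sketched once and parameterised by two conversion functions. This is a minimal illustration, not part of the ContrastiveExplanation API: `wrap_predict_fn`, `to_tensor`, `to_numpy` and the toy model are hypothetical names chosen for this sketch, and the toy model is plain NumPy standing in for a real TensorFlow/PyTorch model.

```python
from typing import Any, Callable

import numpy as np


def wrap_predict_fn(
    model_fn: Callable[[Any], Any],
    to_tensor: Callable[[np.ndarray], Any],
    to_numpy: Callable[[Any], np.ndarray],
) -> Callable[[np.ndarray], np.ndarray]:
    """Turn a framework-specific model call into a NumPy-in, NumPy-out function (hypothetical helper)."""

    def predict_proba(X: np.ndarray) -> np.ndarray:
        X = np.asarray(X)
        if X.ndim == 1:  # a single instance becomes a batch of one
            X = X.reshape(1, -1)
        y = model_fn(to_tensor(X))  # prediction in the framework's own tensor type
        return np.asarray(to_numpy(y))  # hand plain NumPy back to the explainer

    return predict_proba


# Toy stand-in for a framework model (plain NumPy): a fixed linear layer + softmax.
W = np.array([[1.0, -1.0], [0.5, 0.5]])


def toy_model(t: np.ndarray) -> np.ndarray:
    logits = t @ W
    e = np.exp(logits - logits.max(axis=1, keepdims=True))  # numerically stable softmax
    return e / e.sum(axis=1, keepdims=True)


predict_proba = wrap_predict_fn(
    toy_model,
    to_tensor=lambda a: a.astype(np.float32),  # a real framework would build a tensor here
    to_numpy=lambda t: t,  # and convert its output back to NumPy here
)
probs = predict_proba(np.array([1.0, 2.0]))  # single row -> probabilities of shape (1, 2)
```

With PyTorch one could pass `to_tensor=lambda a: torch.as_tensor(a, dtype=torch.float32)` and `to_numpy=lambda t: t.detach().numpy()`; with TensorFlow in eager mode, `to_tensor=lambda a: tf.convert_to_tensor(a, dtype=tf.float32)` and `to_numpy=lambda t: t.numpy()`.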

Example (for PyTorch):

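import torch

def predict_proba(X):
    # `model` is your trained torch.nn.Module (call model.eval() beforehand)
    X = torch.as_tensor(X, dtype=torch.float32)  # convert from numpy to torch
    with torch.no_grad():  # gradients are not needed for predictions
        y = model(X)  # add y = torch.softmax(y, dim=1) if the model outputs logits
    return y.numpy()  # convert from torch back to numpy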

Note that when you are using a GPU, you also have to move X to the GPU before the prediction and move y back to the CPU before converting it to NumPy (e.g. with X.to('cuda') and y.cpu()).
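A device-aware variant of the PyTorch wrapper might look like the sketch below. It is an illustration under assumptions, not the package's own code: `make_torch_predict_proba` and `apply_softmax` are hypothetical names, and it assumes the model has at least one parameter (used to detect its device) and maps a float32 batch of shape (n, d) to class scores of shape (n, k).

```python
import numpy as np
import torch


def make_torch_predict_proba(model: torch.nn.Module, apply_softmax: bool = True):
    """Wrap a PyTorch classifier as a NumPy-in, NumPy-out function (hypothetical helper)."""
    model.eval()  # disable dropout and batch-norm updates
    device = next(model.parameters()).device  # where the model's weights live (CPU or GPU)

    def predict_proba(X: np.ndarray) -> np.ndarray:
        # NumPy -> float32 tensor placed directly on the model's device
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=device)
        with torch.no_grad():
            y = model(X_t)
            if apply_softmax:
                y = torch.softmax(y, dim=1)  # logits -> class probabilities
        return y.cpu().numpy()  # move back to the CPU before converting to NumPy

    return predict_proba


torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 2))  # toy classifier for illustration
predict_proba = make_torch_predict_proba(net)
out = predict_proba(np.random.rand(4, 3))
```

To run on a GPU, move the model first (e.g. `net.to("cuda")`) and then wrap it; the wrapper picks up the device from the model's parameters, so the caller only ever sees NumPy arrays.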