interpretml / DiCE

Generate Diverse Counterfactual Explanations for any machine learning model.
https://interpretml.github.io/DiCE/
MIT License

AttributeError: 'Tensor' object has no attribute 'numpy' #319

Open mozolcer opened 2 years ago

mozolcer commented 2 years ago

I get this error when I run the example code from https://github.com/interpretml/DiCE. Versions: tensorflow 2.9.1, dice-ml 0.8.

My full code:

```python
import os
import numpy as np
import pandas as pd
import dice_ml
import tensorflow as tf
from dice_ml.utils import helpers  # helper functions
from keras import backend as K

# TF v1 compatibility mode (disable_v2_behavior() also disables eager execution)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.config.run_functions_eagerly(True)

# Dataset for training an ML model
d = dice_ml.Data(dataframe=helpers.load_adult_income_dataset(),
                 continuous_features=['age', 'hours_per_week'],
                 outcome_name='income')
# Pre-trained ML model
backend = 'TF' + tf.__version__[0]  # 'TF2' here, since TensorFlow is 2.9.1
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
# DiCE explanation instance
exp = dice_ml.Dice(d, m)

query_instance = {'age': 22,
                  'workclass': 'Private',
                  'education': 'HS-grad',
                  'marital_status': 'Single',
                  'occupation': 'Service',
                  'race': 'White',
                  'gender': 'Female',
                  'hours_per_week': 45}

dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
```

The error traceback:

```
/usr/local/lib/python3.7/dist-packages/dice_ml/explainer_interfaces/dice_tensorflow2.py in generate_counterfactuals(self, query_instance, total_CFs, desired_class, proximity_weight, diversity_weight, categorical_penalty, algorithm, features_to_vary, permitted_range, yloss_type, diversity_loss_type, feature_weights, optimizer, learning_rate, min_iter, max_iter, project_iter, loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm)
    134                                       loss_diff_thres, loss_converge_maxiter, verbose,
    135                                       init_near_query_instance, tie_random, stopping_threshold,
--> 136                                       posthoc_sparsity_param, posthoc_sparsity_algorithm)
    137 
    138         counterfactual_explanations = exp.CounterfactualExamples(

/usr/local/lib/python3.7/dist-packages/dice_ml/explainer_interfaces/dice_tensorflow2.py in find_counterfactuals(self, query_instance, desired_class, optimizer, learning_rate, min_iter, max_iter, project_iter, loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm)
    431 
    432         # find the predicted value of query_instance
--> 433         test_pred = self.predict_fn(tf.constant(query_instance, dtype=tf.float32))[0][0]
    434         if desired_class == "opposite":
    435             desired_class = 1.0 - round(test_pred)

/usr/local/lib/python3.7/dist-packages/dice_ml/explainer_interfaces/dice_tensorflow2.py in predict_fn(self, input_instance)
    148     def predict_fn(self, input_instance):
    149         """prediction function"""
--> 150         temp_preds = self.model.get_output(input_instance).numpy()
    151         return np.array([preds[(self.num_output_nodes-1):] for preds in temp_preds], dtype=np.float32)
    152 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in __getattr__(self, name)
    444         np_config.enable_numpy_behavior()
    445       """)
--> 446     self.__getattribute__(name)
    447 
    448   @staticmethod

AttributeError: 'Tensor' object has no attribute 'numpy'
```

amit-sharma commented 1 year ago

@mozolcer This might be caused by the TF v1 compat code. `tf.disable_v2_behavior()` turns off eager execution, and `.numpy()` is only available on tensors created in eager mode, which is why `predict_fn` fails. Can you try the same code without the v1 compat lines?
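A minimal sketch of the eager-vs-graph distinction described above, assuming TensorFlow 2.x is installed and the script runs in a fresh process where eager execution has not been disabled. Because disabling eager execution is a process-wide setting meant for program startup, the sketch uses a local `tf.Graph().as_default()` context as a stand-in for the global graph mode that `disable_v2_behavior()` switches on. The helper names `eager_value` and `graph_value` are hypothetical, for illustration only; this is not DiCE code.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with its default (eager) behaviour


def eager_value():
    """Hypothetical helper: in eager mode, tf.constant returns an EagerTensor,
    so .numpy() returns a concrete NumPy array."""
    t = tf.constant([[0.25]], dtype=tf.float32)
    return tf.executing_eagerly(), t.numpy()


def graph_value():
    """Hypothetical helper: inside a graph context (a local stand-in for the
    global graph mode set by tf.compat.v1.disable_v2_behavior()), tensors are
    symbolic, so accessing .numpy() raises AttributeError."""
    with tf.Graph().as_default():
        t = tf.constant([[0.25]], dtype=tf.float32)
        eager = tf.executing_eagerly()
        try:
            t.numpy()
            return eager, None
        except AttributeError as err:
            return eager, err


eager_flag, arr = eager_value()
print("eager mode:", eager_flag, "value:", arr)

graph_flag, err = graph_value()
print("graph mode:", graph_flag, "error:", err)
```

In the script from the original report, the equivalent change is to delete the `import tensorflow.compat.v1 as tf` and `tf.disable_v2_behavior()` lines so that `tf.executing_eagerly()` stays `True` when DiCE calls `.numpy()`.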

amit-sharma commented 1 year ago

You may also try the latest version of DiCE, v0.9, which includes updates to the TensorFlow code and should resolve this bug.