interpretml / DiCE

Generate Diverse Counterfactual Explanations for any machine learning model.
https://interpretml.github.io/DiCE/
MIT License
1.37k stars, 188 forks

TypeError: expected str, bytes or os.PathLike object, not CatBoostRegressor #421

Closed: TopCoder2K closed this issue 11 months ago

TopCoder2K commented 11 months ago

Hi there!

I don't understand why dice_ml won't accept a pre-trained model (a CatBoostRegressor in my case) and instead requires a path to the model. The code is:

import dice_ml

d = dice_ml.Data(
    features={
        'season': ['spring', 'summer', 'fall', 'winter'],
        'month': list(range(1, 13)),
        'hour': list(range(0, 24)),
        'holiday': ['False', 'True'],
        'weekday': list(range(0, 7)),
        'workingday': ['False', 'True'],
        'weather': ['clear', 'misty', 'rain'],
        'temp': [-10.0, 50.0],
        'feel_temp': [-10.0, 50.0],
        'humidity': [0.0, 1.0],
        'windspeed': [0.0, 70.0]
    },
    outcome_name='bike_rentals'
)
m = dice_ml.Model(
    model_path=cb_model, backend='sklearn', model_type='regressor'
)
exp = dice_ml.Dice(d, m, method='random')

The calls follow one of the tutorials, so I don't understand why I get:

TypeError                                 Traceback (most recent call last)
<ipython-input-65-0c0c9dbbcb6d> in <cell line: 28>()
     26     model_path=cb_model, backend='sklearn', model_type='regressor'
     27 )
---> 28 exp = dice_ml.Dice(d, m, method='random')

/usr/local/lib/python3.10/dist-packages/dice_ml/dice.py in __init__(self, data_interface, model_interface, method, **kwargs)
     20         :param method: Name of the method to use for generating counterfactuals
     21         """
---> 22         self.decide_implementation_type(data_interface, model_interface, method, **kwargs)
     23 
     24     def decide_implementation_type(self, data_interface, model_interface, method, **kwargs):

/usr/local/lib/python3.10/dist-packages/dice_ml/dice.py in decide_implementation_type(self, data_interface, model_interface, method, **kwargs)
     30                     ' since kdtree explainer needs access to entire training data')
     31         self.__class__ = decide(model_interface, method)
---> 32         self.__init__(data_interface, model_interface, **kwargs)
     33 
     34     def _generate_counterfactuals(self, query_instance, total_CFs,

/usr/local/lib/python3.10/dist-packages/dice_ml/explainer_interfaces/dice_random.py in __init__(self, data_interface, model_interface)
     26 
     27         self.model = model_interface
---> 28         self.model.load_model()  # loading pickled trained model if applicable
     29         self.model.transformer.feed_data_params(data_interface)
     30         self.model.transformer.initialize_transform_func()

/usr/local/lib/python3.10/dist-packages/dice_ml/model_interfaces/base_model.py in load_model(self)
     40     def load_model(self):
     41         if self.model_path != '':
---> 42             with open(self.model_path, 'rb') as filehandle:
     43                 self.model = pickle.load(filehandle)
     44 

TypeError: expected str, bytes or os.PathLike object, not CatBoostRegressor
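
The last frame shows where the error comes from: `load_model()` treats any `model_path` other than the empty string as a file name and hands it to `open()`, so an in-memory model object ends up being opened as a path. The minimal stdlib sketch below reproduces that check; `FakeRegressor` is a hypothetical stand-in for `CatBoostRegressor`, and this `load_model` only mirrors the lines visible in the traceback, not the full DiCE method.

```python
import pickle


class FakeRegressor:
    """Hypothetical stand-in for an in-memory CatBoostRegressor."""


def load_model(model_path):
    # Mirrors the check in the traceback: any value other than '' is
    # treated as a file path and passed straight to open().
    if model_path != '':
        with open(model_path, 'rb') as filehandle:
            return pickle.load(filehandle)
    return None


message = None
try:
    load_model(FakeRegressor())  # a model object passed where a path is expected
except TypeError as err:
    message = str(err)

print(message)
```

Because the only guard is `!= ''`, nothing stops a non-path object from reaching `open()`, which raises the same `TypeError` naming the object's class.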

TopCoder2K commented 11 months ago

Moreover, for some reason the documentation omits the __init__() parameters of the modules, even though they are present in the source code (two screenshots from 2023-12-13, not reproduced here).

Is this done on purpose?
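
Until the rendered docs list them, constructor parameters can be read from the installed package with the standard-library `inspect` module; with DiCE installed, `inspect.signature(dice_ml.Model)` should list the same keywords as the source. The self-contained sketch below uses a hypothetical `ExampleModel` class (its defaults are illustrative, not DiCE's) so it runs without DiCE.

```python
import inspect


class ExampleModel:
    """Hypothetical class standing in for dice_ml.Model; defaults are illustrative only."""

    def __init__(self, model=None, model_path='', backend='sklearn', model_type='classifier'):
        self.model = model
        self.model_path = model_path
        self.backend = backend
        self.model_type = model_type


def describe_params(cls):
    """Return (name, default) pairs for a class's constructor parameters."""
    sig = inspect.signature(cls)  # for a class, this is __init__'s signature without `self`
    return [
        (name, p.default if p.default is not inspect.Parameter.empty else '<required>')
        for name, p in sig.parameters.items()
    ]


for name, default in describe_params(ExampleModel):
    print(f"{name} = {default!r}")
```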

gaugup commented 11 months ago

@TopCoder2K, you could use the model parameter instead of model_path.

https://github.com/interpretml/DiCE/blob/48832802c2a0822a9b203f3057e6def9e8ba0d0a/dice_ml/model_interfaces/base_model.py#L16

That way, since your model is already in memory, you should hopefully not run into the model-loading error.
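
In DiCE itself the fix is to pass the trained object through the `model` keyword, e.g. `dice_ml.Model(model=cb_model, backend='sklearn', model_type='regressor')`, and leave `model_path` unset. The stdlib sketch below illustrates why that works and shows the alternative of pickling to a file first. `ModelWrapper` and `FakeRegressor` are hypothetical simplifications, not DiCE classes, and the sketch assumes `model_path` defaults to `''`, as the `load_model()` check in the traceback suggests.

```python
import os
import pickle
import tempfile


class FakeRegressor:
    """Hypothetical stand-in for a trained model; always predicts 42.0."""

    def predict(self, rows):
        return [42.0 for _ in rows]


class ModelWrapper:
    """Hypothetical model interface accepting an in-memory model or a pickle path."""

    def __init__(self, model=None, model_path=''):
        self.model = model
        self.model_path = model_path

    def load_model(self):
        # Only unpickle when a path was supplied; an in-memory model is kept as-is.
        if self.model_path != '':
            with open(self.model_path, 'rb') as filehandle:
                self.model = pickle.load(filehandle)


# Option 1: pass the trained object via `model`, so load_model() never calls open().
in_memory = ModelWrapper(model=FakeRegressor())
in_memory.load_model()

# Option 2: pickle the model to disk and pass the file path via `model_path`.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model.pkl')
    with open(path, 'wb') as filehandle:
        pickle.dump(FakeRegressor(), filehandle)
    from_disk = ModelWrapper(model_path=path)
    from_disk.load_model()

print(in_memory.model.predict([[1], [2]]))
print(from_disk.model.predict([[1]]))
```

Option 1 is the simpler choice when the model was just trained in the same session; option 2 only matters when the model lives in a separate pickle file.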

TopCoder2K commented 11 months ago

@gaugup, oh damn, what a silly mistake 🤦‍♂ Sorry, I was in a hurry :(

Thank you for your help! It's working now.