deel-ai / xplique

👋 Xplique is a Neural Networks Explainability Toolbox
https://deel-ai.github.io/xplique

[Bug]: - Cannot pickle KernelShap object #144

Closed thomasdulacatos closed 11 months ago

thomasdulacatos commented 11 months ago

Module

Attributions Methods

Current Behavior

Pickling a KernelShap object fails with dill, pickle, and cloudpickle alike, raising: TypeError: cannot pickle '_thread._local' object. This duplicates another ticket I opened two weeks ago, but that one was not clear or precise enough.
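The error text is produced by Python itself: `threading.local` (implemented in CPython as the C type `_thread._local`) refuses to be pickled, which is presumably why dill, pickle, and cloudpickle all fail in the same way. A minimal standard-library check, with no Xplique or TensorFlow involved:

```python
import pickle
import threading


def pickle_error_message(obj):
    """Return the TypeError message raised when pickling obj, or None on success."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


# A thread-local storage object: the same type named at the end of the traceback.
message = pickle_error_message(threading.local())
print(message)  # cannot pickle '_thread._local' object
```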

Expected Behavior

Pickling should work.

Version

1.1.0

Environment

- OS: CentOS 7
- Python version: 3.9.10
- TensorFlow version: 2.12.0
- Package versions:
  - numpy 1.22.0
  - pandas 1.2.0
  - dill 0.3.7
  - scikit-learn 1.2.2

Relevant log output

Traceback (most recent call last):
  File "/home/dulact/gna.py", line 37, in <module>
    main()
  File "/home/dulact/gna.py", line 36, in main
    dump1 = dill.dumps(explainer)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 278, in dumps
    dump(obj, file, protocol, byref, fmode, recurse, **kwds)#, strictio)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 250, in dump
    Pickler(file, protocol, **_kwds).dump(obj)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 418, in dump
    StockPickler.dump(self, obj)
  File "/usr/local/lib/python3.9/pickle.py", line 487, in dump
    self.save(obj)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)
  File "/usr/local/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1212, in save_module_dict
    StockPickler.save_dict(pickler, obj)
  File "/usr/local/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())
  File "/usr/local/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)
  File "/usr/local/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1212, in save_module_dict
    StockPickler.save_dict(pickler, obj)
  File "/usr/local/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())
  File "/usr/local/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)
  File "/usr/local/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1212, in save_module_dict
    StockPickler.save_dict(pickler, obj)
  File "/usr/local/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())
  File "/usr/local/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1965, in save_function
    _save_with_postproc(pickler, (_create_function, (
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1112, in _save_with_postproc
    pickler.save_reduce(*reduction)
  File "/usr/local/lib/python3.9/pickle.py", line 692, in save_reduce
    save(args)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/usr/local/lib/python3.9/pickle.py", line 886, in save_tuple
    save(element)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1965, in save_function
    _save_with_postproc(pickler, (_create_function, (
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1112, in _save_with_postproc
    pickler.save_reduce(*reduction)
  File "/usr/local/lib/python3.9/pickle.py", line 692, in save_reduce
    save(args)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/usr/local/lib/python3.9/pickle.py", line 886, in save_tuple
    save(element)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)
  File "/usr/local/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 1212, in save_module_dict
    StockPickler.save_dict(pickler, obj)
  File "/usr/local/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())
  File "/usr/local/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)
  File "/opt/cloned_file/proma-one-engine/target/venv/build/cpython-3.9.18.final.0/lib/python3.9/site-packages/dill/_dill.py", line 412, in save
    StockPickler.save(self, obj, save_persistent_id)
  File "/usr/local/lib/python3.9/pickle.py", line 578, in save
    rv = reduce(self.proto)
TypeError: cannot pickle '_thread._local' object
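The length of the traceback reflects pickle recursing through the explainer's attribute graph (`save_reduce` and `save_dict` for each object's state, dill's `save_function` for functions serialized by value) until it reaches a single object it cannot serialize. A hypothetical, TensorFlow-free sketch of that propagation: `TracedFunction` loosely imitates a compiled function carrying per-thread state (an assumption, not TensorFlow's actual internals), and `ToyExplainer` stores two of them under the attribute names `similarity_kernel` and `pertub_func`.

```python
import pickle
import threading


class TracedFunction:
    """Hypothetical stand-in for a compiled function that keeps per-thread state."""

    def __init__(self, fn):
        self._per_thread = threading.local()  # unpicklable, like the traceback's culprit
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)


def _kernel(distance):
    return 1.0 / (1.0 + distance)


def _perturb(x):
    return x


class ToyExplainer:
    """Hypothetical explainer holding compiled callables as plain attributes."""

    def __init__(self, model):
        self.model = model
        self.similarity_kernel = TracedFunction(_kernel)
        self.pertub_func = TracedFunction(_perturb)


def try_pickle(obj):
    """Return "ok" if obj pickles, otherwise the TypeError text."""
    try:
        pickle.dumps(obj)
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


# One unpicklable leaf deep inside the instance makes the whole instance unpicklable.
print(try_pickle(ToyExplainer(model=None)))  # TypeError: cannot pickle '_thread._local' object
```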

To Reproduce

Python file:

import numpy as np
import pandas as pd
import dill

from sklearn.preprocessing import MinMaxScaler
from xplique.attributions import KernelShap

class SimpleModel:

    def fit(self, X):
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.scaler.fit(X)
        return self

    def transform(self, X):
        X = self.scaler.transform(X)
        return X

    def predict_proba(self, X):
        return self.transform(X)

def main():

    model = SimpleModel()
    model.fit(pd.DataFrame(np.random.randint(1, 100, size=(100, 100))))

    explainer = KernelShap(model=model)

    explainer.explain(np.random.randint(1, 100, size=(100, 100)),
                      np.random.randint(1, 100, size=(100, 100)))

    model = SimpleModel()
    model.fit(pd.DataFrame(np.random.randint(1, 100, size=(100, 100))))
    explainer = KernelShap(model=model)
    dump1 = dill.dumps(explainer)

main()

lucashervier commented 11 months ago

Hi there!

Just to be sure: do you want to pickle the explainer object or the explanations?

lucashervier commented 11 months ago

So, after investigating with @dv-ai, we found out why the dump isn't working (the same is true for Lime): the similarity_kernel and pertub_func attributes are @tf.function, so a graph is built and it cannot be serialized. There are two solutions we can suggest:

Solution 1: Do not serialize the KernelShap explainer; serialize only your simple model and the explanations. Rebuilding KernelShap once the model is trained is fast:

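def main():

    model = SimpleModel()
    model.fit(pd.DataFrame(np.random.randint(1, 100, size=(100, 100))))
    # Serialize the trained model: it only holds a fitted MinMaxScaler.
    dump1 = dill.dumps(model)

    # Rebuild the explainer from the model whenever it is needed; this is cheap.
    explainer = KernelShap(model=model)

    explanations = explainer.explain(np.random.randint(1, 100, size=(100, 100)),
                                     np.random.randint(1, 100, size=(100, 100)))

    # Serialize the explanations rather than the explainer.
    dump2 = dill.dumps(explanations)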
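A minimal, self-contained sketch of the same save-then-rebuild pattern. It uses the standard `pickle` module so it runs without dill, and hypothetical `ToyModel`/`ToyExplainer` classes in place of `SimpleModel` and `KernelShap`; the `threading.local` attribute is only an assumed stand-in for whatever state makes the real explainer unpicklable.

```python
import pickle
import threading


class ToyModel:
    """Hypothetical picklable model: plain Python state only."""

    def __init__(self, scale):
        self.scale = scale

    def predict(self, xs):
        return [x * self.scale for x in xs]


class ToyExplainer:
    """Hypothetical explainer holding unpicklable per-thread state."""

    def __init__(self, model):
        self._cache = threading.local()  # stand-in for compiled-function state
        self.model = model

    def explain(self, xs):
        # Toy "attribution": just the model output for each input.
        return self.model.predict(xs)


# Build, explain, then persist only the model and the explanations.
model = ToyModel(scale=2)
explanations = ToyExplainer(model).explain([1, 2, 3])
model_blob = pickle.dumps(model)
explanations_blob = pickle.dumps(explanations)

# Later, possibly in another process: restore both and rebuild the explainer.
restored_model = pickle.loads(model_blob)
restored_explanations = pickle.loads(explanations_blob)
explainer = ToyExplainer(restored_model)
print(restored_explanations, explainer.explain([4]))  # [2, 4, 6] [8]
```

The explainer never goes through pickle, so its unpicklable state never has to.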

Solution 2: Run TensorFlow functions eagerly. This may decrease performance, depending on your data and your model:

import tensorflow as tf

def main():

    # Run tf.function-decorated code eagerly instead of as traced graphs.
    tf.config.run_functions_eagerly(True)

    model = SimpleModel()
    model.fit(pd.DataFrame(np.random.randint(1, 100, size=(100, 100))))

    explainer = KernelShap(model=model)

    explanations = explainer.explain(np.random.randint(1, 100, size=(100, 100)),
                                     np.random.randint(1, 100, size=(100, 100)))

    model = SimpleModel()
    model.fit(pd.DataFrame(np.random.randint(1, 100, size=(100, 100))))
    explainer = KernelShap(model=model)
    dump1 = dill.dumps(explainer)

For performance reasons, we do not plan to remove the @tf.function decorator. Could you tell us whether the solutions above work for you?
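If the explainer object itself must be persisted while keeping `@tf.function`, one general Python pattern (not an Xplique feature, and not one suggested in this thread) is to leave the unpicklable attributes out of the pickled state with `__getstate__` and recreate them in `__setstate__`. The sketch applies it to a hypothetical `RebuildingExplainer` class; adapting it to KernelShap would require knowing how that class builds `similarity_kernel` and `pertub_func`, which is not covered here.

```python
import pickle
import threading


class RebuildingExplainer:
    """Hypothetical explainer that drops runtime-only state when pickled."""

    _UNPICKLABLE = ("_cache",)

    def __init__(self, model):
        self.model = model
        self._build_runtime_state()

    def _build_runtime_state(self):
        # Stand-in for attributes such as compiled functions.
        self._cache = threading.local()

    def __getstate__(self):
        # Copy the instance dict, leaving out attributes pickle cannot handle.
        state = self.__dict__.copy()
        for name in self._UNPICKLABLE:
            state.pop(name, None)
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then recreate the runtime-only ones.
        self.__dict__.update(state)
        self._build_runtime_state()


original = RebuildingExplainer(model={"scale": 2})
restored = pickle.loads(pickle.dumps(original))
print(restored.model, isinstance(restored._cache, threading.local))  # {'scale': 2} True
```

Unpickling does not call `__init__`, which is why the rebuild step lives in `__setstate__`.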

thomasdulacatos commented 11 months ago

Thanks for the solutions. I'll use the first one going forward; it seems to work perfectly well. Maybe you could add a note to the docs for future cases like mine :) Have a nice day, Thomas