cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

Multiprocessing - AttributeError: Can't pickle local object #64

Closed subhamkhemka closed 3 years ago

subhamkhemka commented 3 years ago

Hi,

Running the code below, I get an error when using multiprocessing. Please help.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli") #facebook/bart-large-mnli, typeform/distilbert-base-uncased-mnli
model = AutoModelForSequenceClassification.from_pretrained("typeform/distilbert-base-uncased-mnli")
model.cuda()
zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)

def get_att_label(tags, zero_shot_explainer, sentence):
    word_attributions = zero_shot_explainer(
        sentence,
        labels=tags,
    )
    return zero_shot_explainer.predicted_label, word_attributions[zero_shot_explainer.predicted_label]

from torch.multiprocessing import Pool, Process, set_start_method
from functools import partial
from tqdm import tqdm
try:
    set_start_method('spawn')
except RuntimeError:
    pass

if __name__ == '__main__': 
    p = Pool(processes=5)
    get_att_label_fixed_params = partial(get_att_label, tags=tag_values, zero_shot_explainer=zero_shot_explainer)
    predictions = p.map(get_att_label_fixed_params,test_lst)
    p.close()
    p.terminate()
    p.join()

Error -


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-16-91ebdc8201d7> in <module>
      3     get_att_label_fixed_params = partial(get_att_label, tags=tag_values, zero_shot_explainer=zero_shot_explainer)
----> 4     predictions = p.map(get_att_label_fixed_params,test_lst)
      5     p.close()

~/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/pool.py in map(self, func, iterable, chunksize)
    264         in a list that is returned.
    265         '''
--> 266         return self._map_async(func, iterable, mapstar, chunksize).get()
    267 
    268     def starmap(self, func, iterable, chunksize=None):

~/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    642             return self._value
    643         else:
--> 644             raise self._value
    645 
    646     def _set(self, i, obj):

~/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/pool.py in _handle_tasks(taskqueue, put, outqueue, pool, cache)
    422                         break
    423                     try:
--> 424                         put(task)
    425                     except Exception as e:
    426                         job, idx = task[:2]

~/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/connection.py in send(self, obj)
    204         self._check_closed()
    205         self._check_writable()
--> 206         self._send_bytes(_ForkingPickler.dumps(obj))
    207 
    208     def recv_bytes(self, maxlength=None):

~/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/reduction.py in dumps(cls, obj, protocol)
     49     def dumps(cls, obj, protocol=None):
     50         buf = io.BytesIO()
---> 51         cls(buf, protocol).dump(obj)
     52         return buf.getbuffer()
     53 

AttributeError: Can't pickle local object 'LayerIntegratedGradients.attribute.<locals>.gradient_func'

Can you please assist?

Regards, Subham

cdpierse commented 3 years ago

Hi @subhamkhemka,

I'm not sure I can help with this issue; it seems related to the multiprocessing package and its attempt to pickle a particular gradient function. All I know is that strange things can happen when pickling such objects, especially in a multiprocessing environment. Sorry I couldn't be of more help.
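One possible workaround (untested with transformers-interpret, and CUDA models bring their own multiprocessing caveats) is to avoid pickling the explainer entirely: build it inside each worker process via a `Pool` initializer, so only the sentences cross the process boundary. A minimal sketch with a stand-in unpicklable object:

```python
# Sketch: build the expensive/unpicklable object once per worker process
# instead of sending it through the Pool (which requires pickling).
# A lambda stands in for ZeroShotClassificationExplainer here, since
# lambdas, like the explainer's internal gradient_func, cannot be pickled.
from multiprocessing import Pool

_worker_state = {}

def init_worker():
    # In the real code this is where the model/tokenizer would be loaded
    # and the ZeroShotClassificationExplainer constructed.
    _worker_state["explainer"] = lambda sentence: sentence.upper()

def classify(sentence):
    # Uses the per-worker explainer; nothing unpicklable is transferred.
    return _worker_state["explainer"](sentence)

def run_pool(sentences):
    with Pool(processes=2, initializer=init_worker) as pool:
        return pool.map(classify, sentences)

if __name__ == "__main__":
    print(run_pool(["first sentence", "second sentence"]))
```

The trade-off is that each worker pays the model-loading cost once at startup, but after that no pickling of the explainer is needed for each task.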

subhamkhemka commented 3 years ago

no worries, thanks for taking a look @cdpierse

I need to run this for a large dataset of about 1M sentences. Do you have any suggestions on how I could speed up the process?

Thanks

cdpierse commented 3 years ago

@subhamkhemka Calculating attributions for 1M sentences is never going to be fast, because of how the attributions are computed, but if you want them calculated as fast as possible I would reduce `n_steps`.

The default value for `n_steps` is 50, so reducing it will definitely speed up the attribution calculation, but it will almost certainly reduce the accuracy or quality of the attributions, so I'd be careful about reducing it all the way down to 1. If I were you, I would run some tests, assess the quality of the attributions at different values of `n_steps`, and go from there.
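For intuition on this trade-off, here is a toy sketch (pure Python, not transformers-interpret) of how integrated gradients approximates a path integral with an `n_steps`-term Riemann sum; fewer steps means fewer forward/backward passes but a coarser approximation:

```python
# Toy illustration: integrated gradients approximates
#   IG(x) = (x - x0) * integral_0^1 f'(x0 + a*(x - x0)) da
# with an n_steps-term Riemann sum. For f(x) = x**3 and baseline x0 = 0,
# the exact attribution is x**3 itself, so the approximation error is visible.

def grad_f(x):
    # derivative of f(x) = x**3
    return 3 * x ** 2

def integrated_gradients(x, n_steps, baseline=0.0):
    # left Riemann-sum approximation of the path integral above
    step = (x - baseline) / n_steps
    total = sum(grad_f(baseline + k * step) for k in range(n_steps))
    return (x - baseline) * total / n_steps

exact = 2.0 ** 3  # exact attribution for x = 2 is 8.0
for n in (5, 50, 500):
    approx = integrated_gradients(2.0, n)
    # the approximation approaches 8.0 as n_steps grows
    print(n, approx, abs(approx - exact))
```

The same logic applies inside Captum's `LayerIntegratedGradients`: each step costs one forward/backward pass through the model, so halving `n_steps` roughly halves the attribution cost at the price of a coarser integral estimate.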

To use `n_steps`, make sure you are using the latest version of the package (0.5.2).

Good luck with this :-)

subhamkhemka commented 3 years ago

Thanks, will definitely try this out. @cdpierse

I am currently running these in a loop; is batch inference supported, or is it on the roadmap?

cdpierse commented 3 years ago

@subhamkhemka It's not something I have planned on the current roadmap. Technically I don't think it would be too much work, as it would just be a wrapper around the existing explainers that runs over n inputs. But most people will probably be fine using a loop like you are.
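Such a wrapper could be as simple as the sketch below (`explain_batch` is hypothetical, not part of the package; a stub stands in for a real explainer):

```python
def explain_batch(explainer, sentences, labels):
    # Hypothetical batch wrapper: runs the existing explainer over n
    # inputs one at a time and collects the attribution results.
    results = []
    for sentence in sentences:
        word_attributions = explainer(sentence, labels=labels)
        results.append(word_attributions)
    return results

# Stub standing in for a real ZeroShotClassificationExplainer instance:
def stub_explainer(sentence, labels):
    return {label: 0.0 for label in labels}

batch = explain_batch(stub_explainer, ["sent one", "sent two"], ["tag_a", "tag_b"])
print(len(batch))  # one result per input sentence
```

Since each call is independent, this is also the natural seam for adding a progress bar (e.g. wrapping the loop in `tqdm`) or chunking the 1M sentences across processes.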