robinvanschaik / interpret-flair

A small repository to test Captum Explainable AI with a trained Flair transformers-based text classifier.
MIT License

RuntimeError: Input, output and indices must be on the current device #5

Open cpuyyp opened 3 years ago

cpuyyp commented 3 years ago

Hi Robin,

I ran into an error when testing with the pre-trained sentiment model from Flair, loaded simply with:

classifier = TextClassifier.load('sentiment')
flair_model_wrapper = ModelWrapper(classifier)

The rest is the same as in your example. I got this error when calling the function interpret_sentence.

I tested and printed out the device. It turns out that the variable input_ids in the function interpret_sentence is on the CPU. My clumsy solution is to add

input_ids = input_ids.to(device)

after the line

input_ids = flair_model_wrapper.tokenizer.encode(...)

There might be other internal solutions.
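For completeness, a minimal sketch of the patched flow (the encode arguments here are just an illustration of the usual Hugging Face tokenizer API, not the repo's exact call, and device is whatever device the model lives on):

# Encode the sentence into token ids as a PyTorch tensor.
input_ids = flair_model_wrapper.tokenizer.encode(sentence, add_special_tokens=True, return_tensors="pt")
# Move the ids onto the model's device before the forward pass.
input_ids = input_ids.to(device)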

BTW, this work helps a lot!

robinvanschaik commented 3 years ago

Hi @cpuyyp,

Thanks for using the repo and raising the issue. I may have overlooked this because my MacBook does not have a GPU. I will probably adjust it this weekend; you could test it afterwards.

Kind regards

robinvanschaik commented 3 years ago

Hi @cpuyyp,

I have tested this on a GPU notebook on Google Cloud, and it indeed works.

# Store the encoding on the GPU.
input_ids = input_ids.to(flair_model_wrapper.device)

Not sure how to tackle CUDA memory management yet, e.g. if you rerun the function multiple times you might get OOM errors.
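One pattern that might help (a hypothetical helper, not part of this repo) is dropping stale references and clearing PyTorch's CUDA cache between runs:

import gc
import torch

def free_cuda_memory():
    # Drop dangling Python references first so the tensors become collectable.
    gc.collect()
    # Then release cached CUDA blocks back to the driver.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Calling this between interpret_sentence runs may reduce the chance of OOM errors, though it is no guarantee.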

I will test this on my laptop tomorrow, just to make sure I do not break anything on a CPU-only machine.

krzysztoffiok commented 3 years ago

In my case your explanation helped (I had to edit the function manually as you proposed after cloning your repo, so I guess this fix is not in the repo yet?), but I still needed to explicitly set flair.device (flair.device = 'cuda') before starting everything.
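Concretely, this is the setting I mean (assuming a CUDA-capable machine; flair.device is Flair's global device variable):

import flair
import torch

# Set the device explicitly before loading or running any models.
flair.device = torch.device('cuda')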

It's a great thing you have achieved, @robinvanschaik! I had been looking for a solution like this for quite some time. Thanks a lot! Question: if I use your code/repo, is there any research publication you would like me to cite?

robinvanschaik commented 3 years ago

> In my case your explanation helped (I had to edit the function manually as you proposed after cloning your repo, so I guess this fix is not in the repo yet?), but I still needed to explicitly set flair.device (flair.device = 'cuda') before starting everything.
>
> It's a great thing you have achieved, @robinvanschaik! I had been looking for a solution like this for quite some time. Thanks a lot! Question: if I use your code/repo, is there any research publication you would like me to cite?

Hi @krzysztoffiok,

I am glad to hear that you find this repo useful! :)

You are right. For some reason I never got around to actually pushing this fix to master. I guess life got the best of me.

I might have some free time soon. In the meantime I would definitely welcome a pull request that solves this issue!

Regarding citing this repo: I am not affiliated with any academic institution, nor do I write (academic) papers.

This was a hobby project, and I am definitely standing on the shoulders of the Captum & Flair teams. But I appreciate that you are checking in about citing this repo. :)

Is there a way I can help you make sure this work is properly cited?

Then I will add a snippet to the markdown file on the front page.

Cheers

krzysztoffiok commented 3 years ago

Hi @robinvanschaik ,

Thank you for the very quick merge. :) I have created another pull request that forces the user to clearly state the device they will be using. For me this helped.

robinvanschaik commented 3 years ago

Hi @krzysztoffiok,

Thanks for contributing with your pull requests! Keep them coming.

Would you be willing to reflect your changes in the tutorial in the readme.md as well? You probably have the snippet at hand.

I believe we can close this issue once the readme has been updated.

Afterwards I can create a new release.

The generated DOI will reflect the new release as well, which should help with citing the code.

Cheers.

krzysztoffiok commented 3 years ago

@robinvanschaik

OK I will do that.

BTW, I have also noticed a new error (see below), presumably caused by slightly older model versions (this is my guess at the reason). It happened with other models that I fine-tuned ~9 months ago: not only ALBERT, but also BERT and RoBERTa fine-tuned at the same time.

If I'm correct that this is a package version issue, I guess interpret-flair should clearly state which versions of Hugging Face transformers, flair, and captum to use.

AttributeError                            Traceback (most recent call last)
<ipython-input> in <module>
      6     n_steps=500,
      7     estimation_method="gausslegendre",
----> 8     internal_batch_size=3)

~/env/flair06/respect/data/models/respect_5k_final_respect_values_0/interpretation_package/interpret_flair.py in interpret_sentence(flair_model_wrapper, lig, sentence, target_label, visualization_list, n_steps, estimation_method, internal_batch_size)
     62     # Thus we calculate the softmax afterwards.
     63     # For now, I take the first dimension and run this sentence, per sentence.
---> 64     model_outputs = flair_model_wrapper(input_ids)
     65 
     66     softmax = torch.nn.functional.softmax(model_outputs[0], dim=0)

~/env/flair06/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724             _global_forward_hooks.values(),

~/env/flair06/respect/data/models/respect_5k_final_respect_values_0/interpretation_package/flair_model_wrapper.py in forward(self, input_ids)
     44         # Run the input embeddings through all the layers.
     45         # Return the hidden states of the model.
---> 46         hidden_states = self.model(input_ids=input_ids)[-1]
     47 
     48         # BERT has an initial CLS token.

~/env/flair06/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724             _global_forward_hooks.values(),

~/env/flair06/lib/python3.7/site-packages/transformers/modeling_albert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    656             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    657         )
--> 658         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    659 
    660         if input_ids is not None and inputs_embeds is not None:

~/env/flair06/lib/python3.7/site-packages/transformers/configuration_utils.py in use_return_dict(self)
    221         """
    222         # If torchscript is set, force `return_dict=False` to avoid jit errors
--> 223         return self.return_dict and not self.torchscript
    224 
    225     @property

AttributeError: 'AlbertConfig' object has no attribute 'return_dict'
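My guess at the cause (unconfirmed): the checkpoint was saved with an older transformers version whose AlbertConfig predates the return_dict attribute, while the installed transformers expects it. A hypothetical, untested workaround would be to backfill the attribute on the loaded config:

# Hypothetical workaround (untested): backfill the attribute that configs
# saved by older transformers versions are missing. False preserves the
# old tuple-style outputs that the wrapper's forward() indexes into.
if not hasattr(flair_model_wrapper.model.config, "return_dict"):
    flair_model_wrapper.model.config.return_dict = False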
robinvanschaik commented 3 years ago

@krzysztoffiok You are definitely right.

I should have added a requirements.txt to the repository with pinned versions.

This would make it more reproducible.
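For illustration, something along these lines (the version numbers below are guesses based on the flair06 environment in your traceback, not confirmed pins):

# requirements.txt -- illustrative pins only; confirm against the actual environment.
flair==0.6.1
transformers==3.3.0
torch==1.6.0
captum==0.2.0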

krzysztoffiok commented 3 years ago

@robinvanschaik do you think we could meet online to discuss the functionality of the interpret-flair package? I'm not that familiar with the practical use of the IG (Integrated Gradients) method and its various parameters, so if you had time to clarify some aspects, that would greatly help me get proper results. I have tried for a while, and it didn't work as straightforwardly as I expected.
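For reference, these are the parameters I mean, with the values from my traceback above (the values shown are what I used, not recommendations):

# n_steps: number of points used to approximate the IG path integral;
#   more steps generally means a better approximation at higher cost.
# estimation_method: the quadrature rule Captum uses for that integral.
# internal_batch_size: trades GPU memory for speed when evaluating the steps.
interpret_sentence(flair_model_wrapper,
                   lig,
                   sentence,
                   target_label,
                   visualization_list,
                   n_steps=500,
                   estimation_method="gausslegendre",
                   internal_batch_size=3)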

Please contact me at krzysztof.fiok at gmail.com if you agree.