marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier

OOM with Bert Tokenizer #651

Open ma-batita opened 2 years ago

ma-batita commented 2 years ago

Hello,

I am using lime to get explanations for a classification problem. First, I use the Flaubert tokenizer (I also tried different tokenizers and had the same problem) to convert my text to tokens. Then I feed the tokens to my model and get class probabilities with a softmax function (all of this is wrapped in a predict function):

import torch.nn.functional as F

def predict(text):
    # tokenizer and module are the Flaubert tokenizer and the fine-tuned classification model
    encoding = tokenizer(text, padding=True, truncation=True, max_length=100, return_tensors="pt")
    outputs = module(encoding["input_ids"], encoding["attention_mask"])
    # return class probabilities as a numpy array, as LIME expects
    probas = F.softmax(outputs.logits, dim=1).detach().numpy()
    return probas

After that, I created the explainer and everything else that goes with it:

explainer = LimeTextExplainer(class_names=['list', 'of', 'my', 'classes']) # I have 11 classes

msg = "here is my dummy txt" # OOM with this message
exp = explainer.explain_instance(msg, predict, top_labels=1)
exp.show_in_notebook(text=msg)

The problem is that with a short message I get my result (for example msg = "bonjour ca va"), but if I run it with a longer message I get an OOM error after about a minute.

Can you please check whether I missed something here? Thanks!!

wireless911 commented 2 years ago

You can reduce the num_samples parameter (default: 5000), e.g.:

exp = explainer.explain_instance(text, predict, num_features=6, top_labels=2, num_samples=3)

ptschandl commented 1 year ago

You could batch the module() forward pass inside your predict(text) function by processing the texts in batch-sized chunks and concatenating the probas before returning. The num_samples default of 5000 means predict() receives all 5000 perturbed texts in a single forward pass, which is almost certainly what causes the OOM. (also commented here)
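For reference, a minimal sketch of what that batching could look like, reusing the tokenizer, module, and F from the question above; batch_size=32 is just an illustrative value, not something LIME requires:

import numpy as np
import torch
import torch.nn.functional as F

def predict(texts, batch_size=32):
    # LIME passes a list of perturbed texts; process it in fixed-size chunks
    # so the model never sees all num_samples texts at once.
    all_probas = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        encoding = tokenizer(batch, padding=True, truncation=True,
                             max_length=100, return_tensors="pt")
        with torch.no_grad():  # no gradients needed for explanations, saves memory
            outputs = module(encoding["input_ids"], encoding["attention_mask"])
        all_probas.append(F.softmax(outputs.logits, dim=1).cpu().numpy())
    # concatenate back into one (num_samples, num_classes) array for LIME
    return np.concatenate(all_probas, axis=0)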