cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

Token Classification Memory Issue #117

Open mattdeeperinsights opened 1 year ago

mattdeeperinsights commented 1 year ago

Hi there, I am using TokenClassificationExplainer, which works great with small text inputs. However, for a large 512-token input my machine runs out of memory and crashes.

This is probably because of an (unavoidable?) quadratic memory cost: attributions are created for all 512 subtokens with respect to each subtoken, i.e. 512 × 512 pairs. Since the local context is normally the most likely to contribute to the classification of each token, it could be useful to reduce the number of pairs computed for each token. For example, the 10th token is likely more related to tokens 1-20 than to the 512th token.
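To put rough numbers on this, assume one attribution score per (target, source) subtoken pair, with n input subtokens and a hypothetical window of w tokens on each side of the target (the window is the proposal below, not an existing option):

```latex
N_{\text{full}} = n^2 = 512^2 = 262{,}144
\qquad
N_{\text{window}} \le n\,(2w + 1) = 512 \times 21 = 10{,}752 \quad (w = 10)
```

Under that assumption, a 10-token window on each side would cut the number of pairs by roughly 24 times for a full 512-token input.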

Would it be possible to add a context_window_size: Optional[int] = None parameter somewhere (possibly here) to reduce the number of pairs, so that we only calculate attributions for the context_window_size tokens to the left and right of each token?
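A minimal sketch of how such a window could select which (target, source) pairs get attributions. window_bounds and windowed_pairs are hypothetical helpers, not part of transformers-interpret, and the sketch works only on token indices rather than on the explainer's real internals.

```python
from typing import Iterator, Optional, Tuple


def window_bounds(
    target_idx: int, n_tokens: int, context_window_size: Optional[int] = None
) -> Tuple[int, int]:
    """Half-open [start, end) range of source tokens to attribute for one target token.

    Hypothetical helper: None keeps the "every token vs every token" behaviour
    described in this issue.
    """
    if context_window_size is None:
        return 0, n_tokens
    if context_window_size < 0:
        raise ValueError("context_window_size must be non-negative")
    start = max(0, target_idx - context_window_size)
    # +1 so the token context_window_size positions to the right is included
    end = min(n_tokens, target_idx + context_window_size + 1)
    return start, end


def windowed_pairs(
    n_tokens: int, context_window_size: Optional[int] = None
) -> Iterator[Tuple[int, int]]:
    """Yield (target, source) index pairs whose attributions would be computed."""
    for target in range(n_tokens):
        start, end = window_bounds(target, n_tokens, context_window_size)
        for source in range(start, end):
            yield target, source


if __name__ == "__main__":
    n = 512
    full = sum(1 for _ in windowed_pairs(n))
    windowed = sum(1 for _ in windowed_pairs(n, context_window_size=10))
    print(f"all pairs: {full}, windowed pairs: {windowed}")
```

With 512 tokens and context_window_size=10 this yields 10,642 pairs instead of 262,144 (tokens near either end have fewer neighbours, so the count is slightly below 512 × 21). Whether this lowers peak memory in practice depends on how the explainer computes gradients internally, which this sketch does not model.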

I have tried to dig into the code to add this functionality in a PR, but I don't understand all of the functions, as they are missing docstrings and comments (in here).