cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

Support for Reformer #139

Open · jayeshp0 opened this issue 1 year ago

jayeshp0 commented 1 year ago

Even though ReformerForSequenceClassification follows the "{MODEL_NAME}ForSequenceClassification" pattern described in the documentation, I get the RuntimeError below. Since the Explainers are otherwise so simple to use, I assume Reformer models just aren't supported yet.

Here is the last frame of the traceback inside the explainer code:

```
File ~/.local/lib/python3.10/site-packages/transformers_interpret/explainer.py:197, in BaseExplainer._get_preds(self, input_ids, token_type_ids, position_ids, attention_mask)
    194     return preds
    196 elif self.accepts_position_ids:
--> 197     preds = self.model(
    198         input_ids=input_ids,
    199         position_ids=position_ids,
    200         attention_mask=attention_mask,
    201     )
    203     return preds
    204 elif self.accepts_token_type_ids:
```
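The frame above shows that when `self.accepts_position_ids` is true, the explainer passes `position_ids` to the model explicitly instead of leaving them as `None`. The sketch below mirrors only that branch so the dispatch is easy to see. The class, the `inspect.signature`-based `accepts_position_ids` check, and `get_preds` are hypothetical stand-ins, not transformers-interpret's actual implementation.

```python
import inspect
from typing import Optional


class FakeModelWithPositionIds:
    """Hypothetical stand-in for a model whose forward() accepts position_ids."""

    def forward(self, input_ids, position_ids=None, attention_mask=None):
        # Report whether position_ids actually arrived (i.e. were not None).
        return {"position_ids_passed": position_ids is not None}

    # Make instances callable like a torch module: model(...) -> forward(...)
    __call__ = forward


def accepts_position_ids(model) -> bool:
    # Hypothetical check: does the model's forward signature name position_ids?
    return "position_ids" in inspect.signature(model.forward).parameters


def get_preds(model, input_ids, position_ids: Optional[list], attention_mask):
    # Mirrors the branch visible in the traceback: if the model accepts
    # position_ids, they are passed explicitly to the forward call.
    if accepts_position_ids(model):
        return model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)
    return model(input_ids=input_ids, attention_mask=attention_mask)


model = FakeModelWithPositionIds()
ids = [101, 7592, 102]
print(get_preds(model, ids, list(range(len(ids))), [1, 1, 1]))
# → {'position_ids_passed': True}
```

For a model like ReformerForSequenceClassification, whose forward accepts `position_ids`, this branch means a non-`None` tensor reaches the model.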

Here is the last frame, where the error occurred:

```
File ~/.local/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py:2174, in ReformerModel._pad_to_mult_of_chunk_length(self, input_ids, inputs_embeds, attention_mask, position_ids, input_shape, padding_length, padded_seq_length, device)
   2172 # Pad position ids if given
   2173 if position_ids is not None:
-> 2174     padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device)
   2175     padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length)
   2176     position_ids = torch.cat([position_ids, padded_position_ids], dim=-1)

RuntimeError: upper bound and larger bound inconsistent with step sign
```
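The failing line calls `torch.arange(input_shape[-1], padded_seq_length, ...)` with the default positive step, and PyTorch raises this RuntimeError when the start is greater than the end. So the error implies that `padded_seq_length` was smaller than the input length at that point. A minimal sketch of that condition in isolation, assuming only PyTorch; `padded_position_ids` is a hypothetical helper, and the lengths 100, 128, and 64 are made-up illustrative values, not taken from the failing run:

```python
import torch


def padded_position_ids(seq_len: int, padded_seq_length: int) -> torch.Tensor:
    # Hypothetical helper: positions for padding tokens appended after the
    # real sequence. With a positive step, torch.arange needs end >= start.
    return torch.arange(seq_len, padded_seq_length, dtype=torch.long)


# Valid bounds: a 100-token input padded up to 128 (a multiple of 64).
ok = padded_position_ids(100, 128)
print(ok.shape)  # torch.Size([28])

# Invalid bounds: the upper bound (64) is below the start (100).
try:
    padded_position_ids(100, 64)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Per the frame above, this padding branch only runs when `position_ids is not None`, which is consistent with the explainer frame passing `position_ids` explicitly.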

I get the exact same error and stack trace with both SequenceClassificationExplainer and MultiLabelClassificationExplainer (I tried the latter because I am working on multi-label text classification).

Please let me know if you need the full trace or any other information. Thank you.