PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0

Unknown error when running TCAV on a multiclass dataset #793

Closed: terraformmachine closed this issue 1 year ago

terraformmachine commented 2 years ago

I'm getting an unknown error when running TCAV on a multiclass dataset.

[Screenshot]

Steps to reproduce:

  1. Open Colab Notebook: https://colab.research.google.com/drive/1hFwjvr_JLJuR5KmJgmkOSp7VUI1iAlPN?usp=sharing
  2. Run the notebook to render the widget
  3. Search the text for "she|her" and create a slice called "female"
  4. Go to the TCAV tab and select the "female" slice
  5. Run TCAV
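The search in step 3 appears to be a regular expression. If it behaves like Python's `re.search` (an assumption; this sketch does not call LIT), the pattern `she|her` matches substrings as well as whole words, so a sentence containing "there" would also land in the "female" concept slice. The sketch below uses made-up sentences and a hypothetical helper, `matching_indices`, to compare that pattern with a word-bounded version:

```python
import re

# Made-up sample sentences (not taken from the emotion dataset).
texts = [
    "she was so happy today",
    "i feel like there is hope",
    "her smile made my day",
    "he is angry about the delay",
]


def matching_indices(pattern, texts):
    """Return the indices of the texts in which `pattern` matches anywhere."""
    return [i for i, t in enumerate(texts) if re.search(pattern, t)]


loose = matching_indices(r"she|her", texts)           # substring match: also hits "there"
strict = matching_indices(r"\b(?:she|her)\b", texts)  # whole words only
print(loose, strict)
```

With these sentences, the loose pattern selects indices 0, 1 and 2, while the word-bounded pattern selects only 0 and 2.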

Unknown Error :scream:

Dataset class:

```python
import pandas as pd
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types


class EmotionData(lit_dataset.Dataset):
  # emotion dataset:
  ## url:   https://huggingface.co/datasets/emotion
  ## text:  a string feature.
  ## label: a classification label, with possible values including: sadness (0), joy (1), love (2), anger (3), fear (4), surprise (5).

  LABELS = ['0', '1', '2', '3', '4', '5']

  def __init__(self, path):
    df = pd.read_csv(path)
    self._examples = [{
      'text': row['text'],
      'label': row['label']  # integer class ID, as read from the CSV
    } for _, row in df.iterrows()]

  def spec(self):
    return {
      'text': lit_types.TextSegment(),
      'label': lit_types.CategoryLabel(vocab=self.LABELS),
    }
```

Model class:

```python
import re

import torch
import transformers
from lit_nlp.api import model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils


class EmotionModel(model.Model):

    LABELS = ["0", "1", "2", "3", "4", "5"]

    def __init__(self, model_path=None, **kw):
        self._model = transformers.AutoModelForSequenceClassification.from_pretrained(
            model_path, output_hidden_states=True, output_attentions=True
        )
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    def predict_minibatch(self, inputs):
        texts = [x["text"] for x in inputs]
        tokenized_input = self._tokenizer.batch_encode_plus(
            texts,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=128,
            padding="longest",
            truncation=True,
        )

        if torch.cuda.is_available():
            self._model.cuda()
            for tensor in tokenized_input:
                tokenized_input[tensor] = tokenized_input[tensor].cuda()

        outputs = self._model(**tokenized_input)

        batched_outputs = {
            "probas": torch.nn.functional.softmax(outputs.logits, dim=-1),
            "input_ids": tokenized_input["input_ids"],
            "ntok": torch.sum(tokenized_input["attention_mask"], dim=1),
            "cls_emb": outputs.hidden_states[-1][:, 0],
        }

        # Highest class probability per example; gradients are taken w.r.t. this.
        scalar_pred_for_gradients = torch.max(batched_outputs["probas"], dim=1)[0]

        # .cpu() is needed before .numpy() when the model runs on GPU.
        arg_max = torch.argmax(batched_outputs["probas"], dim=-1).cpu().numpy()
        grad_classes = [
            ex.get("grad_class", arg_max[i]) for (i, ex) in enumerate(inputs)
        ]
        # Map string labels to their index in LABELS (this class has no self.config).
        grad_classes = [
            self.LABELS.index(label) if isinstance(label, str) else label
            for label in grad_classes
        ]
        batched_outputs["grad_class"] = torch.tensor(grad_classes)

        # Gradient w.r.t. the input embeddings (hidden_states[0]).
        batched_outputs["input_emb_grad"] = torch.autograd.grad(
            scalar_pred_for_gradients,
            outputs.hidden_states[0],
            grad_outputs=torch.ones_like(scalar_pred_for_gradients),
        )[0]

        for i, layer_attention in enumerate(outputs.attentions):
            batched_outputs[f"layer_{i}/attention"] = layer_attention

        detached_outputs = {
            k: v.cpu().detach().numpy() for k, v in batched_outputs.items()
        }

        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            output["tokens"] = self._tokenizer.convert_ids_to_tokens(
                output.pop("input_ids")[:ntok]
            )
            output["token_grad_sentence"] = output["input_emb_grad"][:ntok]

            # Gradient at the first ([CLS]) token's input embedding.
            output["cls_grad"] = output["input_emb_grad"][0]

            for key in output:
                if not re.match(r"layer_(\d+)/attention", key):
                    continue
                output[key] = output[key][:, :ntok, :ntok].transpose((0, 2, 1))
                output[key] = output[key].copy()

            yield output

    def input_spec(self):
        return {
            "text": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
            # "input_embs": lit_types.TokenEmbeddings(align="tokens", required=False),
            # "grad_class": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
        }

    def output_spec(self):
        ret = {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(vocab=self.LABELS, parent="label"),
            "cls_emb": lit_types.Embeddings(),
            "token_grad_sentence": lit_types.TokenGradients(align="tokens"),
            "cls_grad": lit_types.Gradients(
                grad_for="cls_emb", grad_target_field_key="grad_class"
            ),
        }
        for i in range(self._model.config.num_hidden_layers):
            ret[f"layer_{i}/attention"] = lit_types.AttentionHeads(
                align_in="tokens", align_out="tokens"
            )
        return ret
```

jameswex commented 2 years ago

@terraformmachine thanks for the Colab reproducing the issue!

One issue I noticed is that when you load your dataset, the 'label' field of each example is an integer class ID, but for use in LIT it should be a string from your vocab. In the future, we are looking to add dataset and model field validation to catch and flag these kinds of issues at launch, rather than having them surface as unexpected problems while using LIT.

If you change the 'label' setting in the dataset to str(row['label']), LIT will correctly see the ground-truth labels, and the metrics and classification results modules will display correctly instead of showing the issues they had.
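A minimal sketch of that change, assuming plain dicts stand in for the rows yielded by `df.iterrows()`. `make_examples` is a hypothetical helper, and its vocab check is only an illustration of the kind of field validation mentioned above, not a LIT API:

```python
LABELS = ['0', '1', '2', '3', '4', '5']


def make_examples(rows, labels=LABELS):
    """Build LIT-style example dicts, casting each integer class ID to its string label."""
    examples = []
    for row in rows:
        label = str(row['label'])  # e.g. 1 -> '1', matching the CategoryLabel vocab
        # Hypothetical up-front check: fail fast if the label is not in the vocab.
        if label not in labels:
            raise ValueError(f"label {label!r} is not in the vocab {labels}")
        examples.append({'text': row['text'], 'label': label})
    return examples


# Plain dicts standing in for rows of the emotion CSV.
rows = [{'text': 'i feel great', 'label': 1}, {'text': 'i am scared', 'label': 4}]
print(make_examples(rows))
```

In the dataset class itself, the equivalent one-line change is `'label': str(row['label'])`.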

But I don't think that fixes your TCAV issue. To debug that, I am running the TCAV interpreter directly in the Colab rather than through the UI, so I can see which line of code TCAV fails on. Here is the code I use for that:

```python
from lit_nlp.components import tcav
from lit_nlp.lib import caching
from lit_nlp.api.dataset import IndexedDataset

# Under the hood, LIT wraps each dataset in IndexedDataset and each model in
# CachingModelWrapper, so do the same here before calling the TCAV interpreter.
# `datasets` and `models` are the dicts set up earlier in the notebook.
indexed_datasets = IndexedDataset.index_all(datasets, caching.input_hash)
cached_model = caching.CachingModelWrapper(models["distilbert-base-uncased-emotion"], "distilbert-base-uncased-emotion")

# Run TCAV; the concept set is given as a list of example IDs.
ids = ["ab7570637ea93bafe16a910cbb09b798", "f09151abc3763045e779602009fc1428", "a7796da2aca7a74702762877b9506e0b", "b76b3adaa5b65f2cf0072cdd15b38103",
       "e1c57d1a02d7f2e40225758bce7cfa9f", "be9b032d394de752b85481499168f8aa", "ded6753ba0e8021ba490a36b082e9971",
       "125065a3ebb6fa7cba26fbe048c53532", "a133362931f8837b9164da8cd5dde98b", "7e19494ccea7be91e1aa818192bbf763", "b5152500f9b9fe595e7c183605c21132"]

config = {
    'class_to_explain': "0",
    'concept_set_ids': ids,
    'dataset_name': "emotion",
    'grad_layer': "cls_grad",
}
t = tcav.TCAV()
t.run_with_metadata(indexed_datasets["emotion"].indexed_examples, cached_model, indexed_datasets["emotion"], config=config)
```

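For context, in the original TCAV paper the score for a concept and a class is the fraction of that class's examples whose gradient at the chosen layer (here `cls_grad`) has a positive dot product with the concept activation vector (CAV). A minimal pure-Python sketch of just that final step, assuming the CAV and the per-example gradients are already computed; it omits training the CAV and any significance testing, so it is not LIT's implementation, and `tcav_score` is a hypothetical name:

```python
def tcav_score(gradients, cav):
    """Fraction of examples whose gradient has a positive dot product with the CAV."""
    if not gradients:
        raise ValueError("need at least one example of the class to explain")
    positive = sum(
        1 for g in gradients if sum(gi * ci for gi, ci in zip(g, cav)) > 0
    )
    return positive / len(gradients)


# Toy 2-D gradients for four examples and a toy CAV.
print(tcav_score([[1, 0], [0, 1], [-1, 0], [0, -1]], [1, 1]))
```

Here two of the four toy gradients point toward the concept direction, so the score is 0.5.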
jameswex commented 2 years ago

I was able to get TCAV working for your model once I also changed the grad_class output of your predict method to return the class's string label instead of its integer index. I also uncommented the optional grad_class input in the model's input spec.
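A minimal sketch of that grad_class change, written as a standalone function so it runs without torch or LIT; `grad_classes_as_strings` is a hypothetical helper, not the maintainer's posted code. Because a torch tensor cannot hold strings, the resulting list would also have to bypass `torch.tensor(...)` and the `.numpy()` conversion in `predict_minibatch`, for example by adding it to `detached_outputs` afterwards (that placement is an assumption):

```python
LABELS = ["0", "1", "2", "3", "4", "5"]


def grad_classes_as_strings(inputs, arg_max, labels=LABELS):
    """One string label per example: its own 'grad_class' if given, else the argmax class."""
    out = []
    for i, ex in enumerate(inputs):
        label = ex.get("grad_class", arg_max[i])
        # Integer class IDs map to their vocab string; string labels pass through unchanged.
        out.append(label if isinstance(label, str) else labels[int(label)])
    return out


# The first example falls back to its argmax class (1); the second requests class "3".
print(grad_classes_as_strings([{"text": "a"}, {"text": "b", "grad_class": "3"}], [1, 5]))
```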

Please let me know if this works for you.

terraformmachine commented 2 years ago

@jameswex that fixed it! thank you very much for your help :smiley: