PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0
3.46k stars 352 forks source link

How to visulaize attentions for Hugging face custom models #676

Open pratikchhapolika opened 2 years ago

pratikchhapolika commented 2 years ago

Here is the code:

from lit_nlp.api import types as lit_types
from lit_nlp.examples.datasets import glue
import tensorflow_datasets as tfds
# https://github.com/PAIR-code/lit/wiki/api.md#adding-models-and-data

import sys
from absl import app
from absl import flags
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.lib import utils
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
from transformers import BertTokenizer, BertForSequenceClassification
import pandas as pd
import torch
import transformers

df = pd.read_excel("data",sheet_name='master_data')
print(df.shape)
df = df[df['train'] == 1]
df = df.head(2)
df = df[['UTTERANCE','se']]
df['se'] = df['se'].astype(int)
print(df.head(2))

# sentences = ["He is an uninvited guest.", "The host of the party didn't sent him the invite."]

# tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
# model=BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2,output_hidden_states=True,output_attentions=True,return_dict=True)
# encoded_input=tokenizer.batch_encode_plus(sentences,return_tensors="pt",add_special_tokens=True,max_length=512,padding="longest",truncation="longest_first")
# # print(encoded_input)
# out: transformers.modeling_outputs.SequenceClassifierOutput = model(**encoded_input)
# print(out.attentions)

def load_tfds(*args, do_sort=True, **kw):
    """Load from TFDS, with optional sorting."""
    # Materialize to NumPy arrays.
    # This also ensures compatibility with TF1.x non-eager mode, which doesn't
    # support direct iteration over a tf.data.Dataset.

    # ds = tfds.load('glue/sst2', split='train', shuffle_files=True,download=True)
    ret = df.values.tolist()
    print(ret)
    # if do_sort:
    #     # Recover original order, as if you loaded from a TSV file.
    #     ret.sort(key=lambda ex: ex['idx'])
    return ret

class SST2Data(lit_dataset.Dataset):
    """Stanford Sentiment Treebank, binary version (SST-2).
    See https://www.tensorflow.org/datasets/catalog/glue#gluesst2.
    """

    LABELS = ['0', '1']

    def __init__(self, data):
        self._examples = []
        for ex in load_tfds(df):
            self._examples.append({'sentence': ex[0],'label': self.LABELS[ex[1]], })

        print(self._examples)

    def spec(self):
        return {
            'sentence': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=self.LABELS)
        }

FLAGS = flags.FLAGS

FLAGS.set_default("development_demo", True)

flags.DEFINE_string(
    "model_path",
    "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
    "Path to trained model, in standard transformers format, e.g. as "
    "saved by model.save_pretrained() and tokenizer.save_pretrained()")

def _from_pretrained(cls, *args, **kw):
    """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
    try:
        return cls.from_pretrained(*args, **kw)
    except OSError as e:
        logging.warning("Caught OSError loading model: %s", e)
        logging.warning("Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
        return cls.from_pretrained(*args, from_tf=True, **kw)

class SimpleSEModel(lit_model.Model):
    """Simple SE classification model."""

    @property
    def num_layers(self):
        return self.model.config.num_hidden_layers
    @property
    def max_seq_length(self):
        return self.model.config.max_position_embeddings

    LABELS = ["0", "1"]  # negative, positive

    def __init__(self, model_name_or_path):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # This is a just a regular PyTorch model.
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2,output_hidden_states=True,output_attentions=True)
        self.model.eval()

    ##
    # LIT API implementation
    def max_minibatch_size(self):
        # This tells lit_model.Model.predict() how to batch inputs to
        # predict_minibatch().
        # Alternately, you can just override predict() and handle batching yourself.
        return 16

    def predict_minibatch(self, inputs):
        # Preprocess to ids and masks, and make the input batch.
        encoded_input = self.tokenizer.batch_encode_plus([ex["sentence"] for ex in inputs],return_tensors="pt",add_special_tokens=True,max_length=512,padding="longest",truncation="longest_first")

        # Check and send to cuda (GPU) if available
        if torch.cuda.is_available():
            self.model.cuda()
            for tensor in encoded_input:
                encoded_input[tensor] = encoded_input[tensor].cuda()
        # Run a forward pass.
        with torch.no_grad():  # remove this if you need gradients.
            out: transformers.modeling_outputs.SequenceClassifierOutput = self.model(**encoded_input)
            unused_attentions = out.attentions
            # print(unused_attentions)
            # print(type(unused_attentions))

        # Post-process outputs.
        batched_outputs = {
            "probas": torch.nn.functional.softmax(out.logits, dim=-1).tolist(),
            "input_ids": encoded_input["input_ids"],
            "ntok": torch.sum(encoded_input["attention_mask"], dim=1).tolist(),
            "cls_emb": out.hidden_states[-1][:, 0].tolist(),  # last layer, first token
        }

        for i in range(len(unused_attentions)):
            batched_outputs[f"layer_{i:d}_attention"] = unused_attentions[i].detach().numpy()

        # unbatched_outputs = utils.unbatch_preds(batched_outputs)
        # Return as NumPy for further processing.

        # for k, v in batched_outputs.items():
        #     print("batched_output")
        #     print(v)
        #     print(type(v))
        detached_outputs = {k: v for k, v in batched_outputs.items()}
        print("detached_outputs")
        print(detached_outputs)
        # Unbatch outputs so we get one record per input example.
        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(output.pop("input_ids")[1:ntok - 1])
            yield output

    def input_spec(self) -> lit_types.Spec:
        return {
            "sentence": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
            "tokens": lit_types.Tokens(parent='text', required=False)
        }

    def output_spec(self) -> lit_types.Spec:
        spec = {
            "tokens": lit_types.Tokens(parent="text"),
            "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,null_idx=0),
            "cls_emb": lit_types.Embeddings()}
        for i in range(self.num_layers):
            spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads(align_in="tokens", align_out="tokens")
        return spec

def get_wsgi_app():
    """Returns a LitApp instance for consumption by gunicorn."""
    FLAGS.set_default("server_type", "external")
    FLAGS.set_default("demo_mode", True)
    # Parse flags without calling app.run(main), to avoid conflict with
    # gunicorn command line flags.
    unused = flags.FLAGS(sys.argv, known_only=True)
    return main(unused)

def main(_):
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    model_path = FLAGS.model_path
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(model_path, extract_compressed_file=True)

    # Load the model we defined above.
    models = {"sst": SimpleSEModel(model_path)}
    # Load SST-2 validation set from TFDS.
    datasets = {"sst_dev": SST2Data(df)}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    return lit_demo.serve()

if __name__ == "__main__":
    app.run(main)

In the tool it gives the following error:

`zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
I0323 23:11:39.272578 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:11:39] "POST /get_interpretations?model=sst&dataset_name=sst_dev&interpreter=metrics HTTP/1.1" 200 -
I0323 23:12:00.909870 4354911744 app.py:138] 1 of 1 inputs sent as IDs; reconstituting from dataset 'sst_dev'
I0323 23:12:00.910048 4354911744 caching.py:202] CachingModelWrapper 'sst': misses (dataset=sst_dev): []
I0323 23:12:00.910130 4354911744 caching.py:204] CachingModelWrapper 'sst': 0 misses out of 1 inputs
I0323 23:12:00.910184 4354911744 caching.py:210] Prepared 0 inputs for model
I0323 23:12:00.910248 4354911744 caching.py:212] Received 0 predictions from model
/Users/pchhapolika/opt/anaconda3/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
I0323 23:12:00.913743 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:12:00] "POST /get_interpretations?model=sst&dataset_name=sst_dev&interpreter=metrics HTTP/1.1" 200 -
I0323 23:12:00.914530 4354911744 app.py:138] 1 of 1 inputs sent as IDs; reconstituting from dataset 'sst_dev'
I0323 23:12:00.914641 4354911744 caching.py:202] CachingModelWrapper 'sst': misses (dataset=sst_dev): []
I0323 23:12:00.914699 4354911744 caching.py:204] CachingModelWrapper 'sst': 0 misses out of 1 inputs
I0323 23:12:00.914747 4354911744 caching.py:210] Prepared 0 inputs for model
I0323 23:12:00.914806 4354911744 caching.py:212] Received 0 predictions from model
I0323 23:12:00.914860 4354911744 app.py:187] Requested types: ['MulticlassPreds']
I0323 23:12:00.914927 4354911744 app.py:197] Will return keys: {'probas'}
I0323 23:12:00.915141 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:12:00] "POST /get_preds?model=sst&dataset_name=sst_dev&requested_types=MulticlassPreds HTTP/1.1" 200 -
I0323 23:12:00.925801 4354911744 app.py:138] 1 of 1 inputs sent as IDs; reconstituting from dataset 'sst_dev'
I0323 23:12:00.925930 4354911744 caching.py:202] CachingModelWrapper 'sst': misses (dataset=sst_dev): []
I0323 23:12:00.925989 4354911744 caching.py:204] CachingModelWrapper 'sst': 0 misses out of 1 inputs
I0323 23:12:00.926041 4354911744 caching.py:210] Prepared 0 inputs for model
I0323 23:12:00.926104 4354911744 caching.py:212] Received 0 predictions from model
I0323 23:12:00.926160 4354911744 app.py:187] Requested types: ['MulticlassPreds']
I0323 23:12:00.926231 4354911744 app.py:197] Will return keys: {'probas'}
I0323 23:12:00.926454 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:12:00] "POST /get_preds?model=sst&dataset_name=sst_dev&requested_types=MulticlassPreds HTTP/1.1" 200 -
I0323 23:12:00.928698 4354911744 app.py:138] 1 of 1 inputs sent as IDs; reconstituting from dataset 'sst_dev'
I0323 23:12:00.928821 4354911744 caching.py:202] CachingModelWrapper 'sst': misses (dataset=sst_dev): []
I0323 23:12:00.928880 4354911744 caching.py:204] CachingModelWrapper 'sst': 0 misses out of 1 inputs
I0323 23:12:00.928930 4354911744 caching.py:210] Prepared 0 inputs for model
I0323 23:12:00.928992 4354911744 caching.py:212] Received 0 predictions from model
I0323 23:12:00.929049 4354911744 app.py:187] Requested types: ['MulticlassPreds']
I0323 23:12:00.929118 4354911744 app.py:197] Will return keys: {'probas'}
I0323 23:12:00.929338 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:12:00] "POST /get_preds?model=sst&dataset_name=sst_dev&requested_types=MulticlassPreds HTTP/1.1" 200 -
I0323 23:12:00.929982 4354911744 app.py:138] 1 of 1 inputs sent as IDs; reconstituting from dataset 'sst_dev'
I0323 23:12:00.930073 4354911744 caching.py:202] CachingModelWrapper 'sst': misses (dataset=sst_dev): []
I0323 23:12:00.930128 4354911744 caching.py:204] CachingModelWrapper 'sst': 0 misses out of 1 inputs
I0323 23:12:00.930179 4354911744 caching.py:210] Prepared 0 inputs for model
I0323 23:12:00.930235 4354911744 caching.py:212] Received 0 predictions from model
I0323 23:12:00.930291 4354911744 app.py:187] Requested types: ['Tokens', 'AttentionHeads']
I0323 23:12:00.930380 4354911744 app.py:197] Will return keys: {'layer_5_attention', 'layer_9_attention', 'layer_11_attention', 'tokens', 'layer_1_attention', 'layer_0_attention', 'layer_8_attention', 'layer_7_attention', 'layer_3_attention', 'layer_6_attention', 'layer_10_attention', 'layer_4_attention', 'layer_2_attention'}
I0323 23:12:02.976485 4354911744 _internal.py:113] 127.0.0.1 - - [23/Mar/2022 23:12:02] "POST /get_preds?model=sst&dataset_name=sst_dev&requested_types=Tokens,AttentionHeads HTTP/1.1" 200 -
jameswex commented 2 years ago

To return values from predict_minibatch, need to convert that tensor([0.6403, 0.3597]) into a raw array of just [0.6403, 0.3597] as opposed to a tensor

pratikchhapolika commented 2 years ago

To return values from predict_minibatch, need to convert that tensor([0.6403, 0.3597]) into a raw array of just [0.6403, 0.3597] as opposed to a tensor

In which line of code?

jameswex commented 2 years ago

Not sure, you should check all your entries in batched_output to be sure they are normal python lists and not tensors. It might be the 'probas' entry that is the issue here.

pratikchhapolika commented 2 years ago

Not sure, you should check all your entries in batched_output to be sure they are normal python lists and not tensors. It might be the 'probas' entry that is the issue here.

Updated the code but getting this warning not error.

pratikchhapolika commented 2 years ago

@jameswex how can I launch the app in jupyter notebook itself instead as web page? How can I modify above code to do it?

pratikchhapolika commented 2 years ago

@jameswex second question is, how to get gradient visulaization in salience maps. In the above code?

pratikchhapolika commented 2 years ago

When I change to PCA viz its gives TypeError: (-0.7481077572209469+0j) is not JSON serializable.

jameswex commented 2 years ago

To run in a notebook, in your jupyter notebook, create your dataset and model classes and then create a LitWidget object with those objects and call render on it. An example can be seen here https://colab.sandbox.google.com/github/PAIR-code/lit/blob/main/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb in colab, but the code would be the same in jupyter.

If you want to see gradient-based salience methods in the LIT UI, then your model will need to have the apporpriate inputs and outputs to support them. See https://github.com/PAIR-code/lit/wiki/components.md#token-based-salience for details for having your model support the different salience methods.

pratikchhapolika commented 2 years ago

To run in a notebook, in your jupyter notebook, create your dataset and model classes and then create a LitWidget object with those objects and call render on it. An example can be seen here https://colab.sandbox.google.com/github/PAIR-code/lit/blob/main/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb in colab, but the code would be the same in jupyter.

If you want to see gradient-based salience methods in the LIT UI, then your model will need to have the apporpriate inputs and outputs to support them. See https://github.com/PAIR-code/lit/wiki/components.md#token-based-salience for details for having your model support the different salience methods.

OK. 

Also how to overcome TypeError: (-0.7481077572209469+0j) is not JSON serializable.
jameswex commented 2 years ago

The model and dataset code shouldn't change for notebooks. It's just that you create a LitWidget with the model and datasets, instead of a Server. Then you call render on the widget object.

I'm not sure about the root cause of that specific error. It's most likely that your predict_minibatch fn is returning some value for one of its fields for each example that isn't a basic, JSON-serializable type.

pratikchhapolika commented 2 years ago

The model and dataset code shouldn't change for notebooks. It's just that you create a LitWidget with the model and datasets, instead of a Server. Then you call render on the widget object.

I'm not sure about the root cause of that specific error. It's most likely that your predict_minibatch fn is returning some value for one of its fields for each example that isn't a basic, JSON-serializable type.

Converted all to list. Still same error.

def predict_minibatch(self, inputs):
        # Preprocess to ids and masks, and make the input batch.
        encoded_input = self.tokenizer.batch_encode_plus([ex["sentence"] for ex in inputs],return_tensors="pt",add_special_tokens=True,max_length=512,padding="longest",truncation="longest_first")

        # Check and send to cuda (GPU) if available
        if torch.cuda.is_available():
            self.model.cuda()
            for tensor in encoded_input:
                encoded_input[tensor] = encoded_input[tensor].cuda()
        # Run a forward pass.
        with torch.no_grad():  # remove this if you need gradients.
            out: transformers.modeling_outputs.SequenceClassifierOutput = self.model(**encoded_input)
            unused_attentions = out.attentions
            # print(unused_attentions)
            # print(type(unused_attentions))

        # Post-process outputs.
        batched_outputs = {
            "probas": torch.nn.functional.softmax(out.logits, dim=-1).tolist(),
            "input_ids": encoded_input["input_ids"],
            "ntok": torch.sum(encoded_input["attention_mask"], dim=1).tolist(),
            "cls_emb": out.hidden_states[-1][:, 0].tolist(),  # last layer, first token
        }

        for i in range(len(unused_attentions)):
            batched_outputs[f"layer_{i:d}_attention"] = unused_attentions[i].detach().numpy()

        # unbatched_outputs = utils.unbatch_preds(batched_outputs)
        # Return as NumPy for further processing.

        # for k, v in batched_outputs.items():
        #     print("batched_output")
        #     print(v)
        #     print(type(v))
        detached_outputs = {k: v for k, v in batched_outputs.items()}
        # print("detached_outputs")
        # print(detached_outputs)
        # Unbatch outputs so we get one record per input example.
        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(output.pop("input_ids")[1:ntok - 1])
            # print('output["tokens"]')
            # print(output)
            yield output
iftenney commented 2 years ago

Can you print the contents of batched_outputs, including types?

The error above:

TypeError: (-0.7481077572209469+0j) is not JSON serializable.

Looks like the value is a complex number a+bj, which is probably why it's not able to be serialized. NumPy arrays of floats should be fine, though; they'll be automatically converted to lists here: https://github.com/PAIR-code/lit/blob/main/lit_nlp/lib/serialize.py#L32

pratikchhapolika commented 2 years ago

Can you print the contents of batched_outputs, including types?

The error above:

TypeError: (-0.7481077572209469+0j) is not JSON serializable.

Looks like the value is a complex number a+bj, which is probably why it's not able to be serialized. NumPy arrays of floats should be fine, though; they'll be automatically converted to lists here: https://github.com/PAIR-code/lit/blob/main/lit_nlp/lib/serialize.py#L32

@iftenney here is the outputs.

batched_outputs after for loop

for i in range(len(unused_attentions)):
            batched_outputs[f"layer_{i:d}_attention"] = unused_attentions[i].detach().numpy()
{'probas': [[0.6018652319908142, 0.3981347680091858], [0.5785479545593262, 0.42145204544067383], [0.6183280348777771, 0.3816719651222229], [0.6127758026123047, 0.3872241675853729]], 'input_ids': tensor([[  101,  2079,  2017,  2031,  1037, 19085,  2030,  1037, 12436, 20876,
          1029,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2026,  5980,  2097,  5091,  2022,  3407,  2085,   999,   999,
          6293,  2026,  2524,  1010, 14908,  2266,  2046,  2115, 12436, 20876,
          3531,   999,   999,   999,   999,   999,   999,  1029, 10047,  2013,
          3742,  1057,  1029,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1045,  2031,  1037, 19085,  1012,  1045,  2215,  2000, 13988,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  8840,  2140,  8700,  3348, 12436, 20876,  3398,  2054,  1996,
         12436, 20876,  2054,  2003,  2009,  2170,  1045,  2409,  1057,  1996,
         12436, 20876,  2087,  2450,  2655,  2009,  1037, 22418,  3398,  4521,
          4596,  2012,  2026,  2814, 17710, 13668,  2054,  2106,  2017,  1998,
          2115,  3611,  2079,  2253,  2041,  2000,  4521,  2253,  2005,  1037,
          3298,  2074,  2985,  2051,  2362,  2042,  2785,  1997, 26352,  2098,
          2041,  2651,  2339,  1029,  2049,  7929,  2073,  2106,  2017,  4553,
          2055,  3348,  4033,  1005,  1056,  2428,  2021,  2113,  2070,  2616,
          2073,  2106,  1057,  2175,  1029,  7592,  1029,   102]]), 'ntok': [12, 34, 12, 88], 'cls_emb': [[-0.014076177030801773, -0.0728173702955246, -0.078043133020401, 0.0938369482755661, -0.17423537373542786, 0.07189369201660156, 0.6690779328346252, 1.1941571235656738, -0.5418111085891724, 0.09891873598098755, 0.34711796045303345, -0.3437187671661377, -0.1604285091161728, -0.10622479021549225, 0.3024073839187622, 0.12053345888853073, -0.01676577888429165, ......]]

'layer_0_attention': array([[[[4.81520668e-02, 4.21391986e-02, 2.80070100e-02, ...,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
         [1.28935039e-01, 3.51342373e-02, 8.70151743e-02, ...,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
         [1.03371695e-01, 5.93042485e-02, 6.06599301e-02, ...,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
         ..., dtype=float32), 'layer_1_attention': array([[[[3.56219172e-01....

batched_output key and value [[0.6018652319908142, 0.3981347680091858], [0.5785479545593262, 0.42145204544067383], [0.6183280348777771, 0.3816719651222229], [0.6127758026123047, 0.3872241675853729]] <class 'list'>


batched_output key and value

 for k, v in batched_outputs.items():
            print("batched_output key and value")
            print(v)
            print(type(v))
            print("*******************************************")
tensor([[  101,  2079,  2017,  2031,  1037, 19085,  2030,  1037, 12436, 20876,
          1029,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2026,  5980,  2097,  5091,  2022,  3407,  2085,   999,   999,
          6293,  2026,  2524,  1010, 14908,  2266,  2046,  2115, 12436, 20876,
          3531,   999,   999,   999,   999,   999,   999,  1029, 10047,  2013,
          3742,  1057,  1029,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1045,  2031,  1037, 19085,  1012,  1045,  2215,  2000, 13988,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  8840,  2140,  8700,  3348, 12436, 20876,  3398,  2054,  1996,
         12436, 20876,  2054,  2003,  2009,  2170,  1045,  2409,  1057,  1996,
         12436, 20876,  2087,  2450,  2655,  2009,  1037, 22418,  3398,  4521,
          4596,  2012,  2026,  2814, 17710, 13668,  2054,  2106,  2017,  1998,
          2115,  3611,  2079,  2253,  2041,  2000,  4521,  2253,  2005,  1037,
          3298,  2074,  2985,  2051,  2362,  2042,  2785,  1997, 26352,  2098,
          2041,  2651,  2339,  1029,  2049,  7929,  2073,  2106,  2017,  4553,
          2055,  3348,  4033,  1005,  1056,  2428,  2021,  2113,  2070,  2616,
          2073,  2106,  1057,  2175,  1029,  7592,  1029,   102]])
<class 'torch.Tensor'>
*******************************************
batched_output key and value
[12, 34, 12, 88]
<class 'list'>

detached_outputs

detached_outputs = {k: v for k, v in batched_outputs.items()}
        print("detached_outputs")
{'probas': [[0.6018652319908142, 0.3981347680091858], [0.5785479545593262, 0.42145204544067383], [0.6183280348777771, 0.3816719651222229], [0.6127758026123047, 0.3872241675853729]], 'input_ids': tensor([[  101,  2079,  2017,  2031,  1037, 19085,  2030,  1037, 12436, 20876,
          1029,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2026,  5980,  2097,  5091,  2022,  3407,  2085,   999,   999,
          6293,  2026,  2524,  1010, 14908,  2266,  2046,  2115, 12436, 20876,
          3531,   999,   999,   999,   999,   999,   999,  1029, 10047,  2013,
          3742,  1057,  1029,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1045,  2031,  1037, 19085,  1012,  1045,  2215,  2000, 13988,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  8840,  2140,  8700,  3348, 12436, 20876,  3398,  2054,  1996,
         12436, 20876,  2054,  2003,  2009,  2170,  1045,  2409,  1057,  1996,
         12436, 20876,  2087,  2450,  2655,  2009,  1037, 22418,  3398,  4521,
          4596,  2012,  2026,  2814, 17710, 13668,  2054,  2106,  2017,  1998,
          2115,  3611,  2079,  2253,  2041,  2000,  4521,  2253,  2005,  1037,
          3298,  2074,  2985,  2051,  2362,  2042,  2785,  1997, 26352,  2098,
          2041,  2651,  2339,  1029,  2049,  7929,  2073,  2106,  2017,  4553,
          2055,  3348,  4033,  1005,  1056,  2428,  2021,  2113,  2070,  2616,
          2073,  2106,  1057,  2175,  1029,  7592,  1029,   102]]), 'ntok': [12, 34, 12, 88], 'cls_emb': [[-0.014076177030801773, -0.0728173702955246, -0.078043133020401, 0.0938369482755661, -0.17423537373542786, 0.07189369201660156, 0.6690779328346252, 1.1941571235656738, -0.5418111085891724, 0.09891873598098755, 0.34711796045303345, -0.3437187671661377, -0.1604285091161728, -0.10622479021549225, 0.3024073839187622, 0.12053345888853073, -0.01676577888429165, 0.67
iftenney commented 2 years ago

Thanks, all of those values look okay although the indentation is very strange so I could be missing something. Can you post the error you're still seeing? You might try running under pdb and seeing which field it's coming from.

aryan1107 commented 2 years ago

@pratikchhapolika To visualize Huggingface models you can start by adding any basic models directly to LIT. Here is one example which I did using Huggingface.... the code might help https://github.com/PAIR-code/lit/pull/691