PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0

Does it support embedding as input #1420

Open ltroin opened 6 months ago

ltroin commented 6 months ago

Hello, I am wondering: is there any plan to support torch tensors as data input?

RyanMullins commented 6 months ago

Hi @ltroin, can you be more specific about the model or component you're trying to use? LIT's model wrappers generally assume the data will be sent in as represented in the dataset, so typically text, numbers, or image bytes, not pre-computed tensors. Components (interpreters, projectors, etc.) are more flexible, but they also tend to assume that the incoming data will match the dataset or whatever the model outputs. That said, we should be able to help you adapt/customize the LIT code you want to use to fit your needs.

iftenney commented 6 months ago

You can use tensor data for Embeddings, TokenEmbeddings, or similar types; just be sure to convert it to NumPy arrays, because LIT expects plain old data that can be serialized between Python and JavaScript.
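A minimal sketch of that conversion, assuming the model produces PyTorch tensors. The helper `to_plain_data` is hypothetical, not part of LIT; it checks for a `detach()` method instead of importing torch, so it also runs where PyTorch is not installed. The `json.dumps` call only illustrates that the result reduces to plain, serializable data; it is not how LIT itself serializes values.

```python
import json

import numpy as np


def to_plain_data(value):
    """Hypothetical helper: turn a tensor-like value into a float32 NumPy array.

    torch.Tensor exposes detach()/cpu()/numpy(); duck-typing on `detach`
    lets this sketch run without PyTorch installed.
    """
    if hasattr(value, "detach"):
        # Drop autograd history and move to host memory before converting.
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float32)


# Stand-in for a (num_tokens, hidden_dims) embedding tensor from a model.
embeddings = to_plain_data([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(embeddings.shape)  # (2, 3)

# NumPy arrays are not JSON-serializable directly; .tolist() gives nested lists.
payload = json.dumps({"token_embeddings": embeddings.tolist()})
```

With a real tensor you would call `to_plain_data(model_output)` on the torch value before returning it from the wrapper's prediction method.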

ltroin commented 6 months ago

Thank you for the quick response. I'm currently working with the Llama model, which accepts two types of input: tokenized text or input embeddings. I'm interested in using the input embeddings to analyze and visualize attention-weight patterns across different layers, and to create a salience map over "tokens", where each "token" corresponds to a row of the input embeddings.

ltroin commented 6 months ago

You can use tensor data for Embeddings, TokenEmbeddings, or similar types - you will just want to be sure to convert it to numpy arrays as LIT expects plain-old-data that can be serialized between Python and JavaScript.

Does this imply that the LIT framework will interpret this numpy array as input embeddings rather than raw text?

RyanMullins commented 6 months ago

Does this imply that the LIT framework will interpret this numpy array as input embeddings rather than raw text?

tl;dr -- Probably not. NumPy arrays are not raw text and LIT expects to operate over language-native string values when dealing with strings, so if you passed a NumPy array in a place where it was expecting a Python str you would most likely get a ValueError.

Longer explanation -- LIT operates over a JSON Object representation of examples. Since JSON has very limited support for types, LIT provides its own type system, which our TypeScript and Python codebases use to decide how to handle the different values in the JSON Objects we pass around. Model and Dataset classes declare the shape (i.e., field names and types) of the JSON Objects they provide/accept as Specs, and components (interpreters, generators, metrics, and UI modules) look for specific LIT types in these Specs to determine compatibility and decide how to handle them. As above, you should expect a ValueError (or similar) if you pass around a value that does not conform to the expected type.
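To make the Spec idea concrete, here is a toy, pure-Python imitation of it. The `TextSegment` and `Embeddings` classes below are simplified stand-ins named after LIT types, not the real classes in `lit_nlp.api.types`, and `validate` is a hypothetical helper. The sketch only shows how a mapping from field names to types lets code reject a NumPy array where a string is expected.

```python
import numpy as np


class TextSegment:
    """Toy stand-in for a LIT text type: the value must be a Python str."""

    def check(self, value):
        if not isinstance(value, str):
            raise ValueError(f"expected str, got {type(value).__name__}")


class Embeddings:
    """Toy stand-in for a LIT embedding type: the value must be a 1-D float array."""

    def check(self, value):
        arr = np.asarray(value)
        if arr.ndim != 1 or not np.issubdtype(arr.dtype, np.floating):
            raise ValueError(f"expected 1-D float array, got shape {arr.shape}")


def validate(example, spec):
    """Hypothetical helper: check each field of a JSON-like example against a spec."""
    for name, field_type in spec.items():
        if name not in example:
            raise ValueError(f"missing field {name!r}")
        field_type.check(example[name])


# Specs map field names to types, in the spirit of input_spec()/output_spec().
input_spec = {"prompt": TextSegment()}
output_spec = {"prompt_embs": Embeddings()}

validate({"prompt": "Hello, Llama"}, input_spec)  # OK
validate({"prompt_embs": np.zeros(8, dtype=np.float32)}, output_spec)  # OK

try:
    # A NumPy array where a str is expected is rejected.
    validate({"prompt": np.zeros((4, 8), dtype=np.float32)}, input_spec)
except ValueError as err:
    print(err)  # expected str, got ndarray
```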

I'm currently working with the Llama model

We don't have a wrapper for Llama in an official release yet, but we're working on adding one in https://github.com/PAIR-code/lit/pull/1421. It's designed to take raw text as the input and then the HF implementation handles tokenization, embedding, generation, etc. It would be possible for you to subclass this (or write your own) so that the wrapper class takes embeddings as input instead of raw text.
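As a rough illustration of a wrapper whose input is an embedding matrix rather than raw text, the sketch below replaces Llama with a single randomly initialized, causally masked self-attention layer in NumPy so it runs anywhere. `EmbeddingInputModel`, its field names, and its weights are all hypothetical and do not follow LIT's `Model` base class. A real wrapper would instead pass the embeddings to the Hugging Face model, whose forward pass accepts `inputs_embeds` and can return per-layer attention weights when called with `output_attentions=True`, and would declare LIT types in its Specs.

```python
import numpy as np


class EmbeddingInputModel:
    """Hypothetical toy wrapper whose input is a (num_tokens, hidden_dims) matrix.

    One random, causally masked self-attention layer stands in for Llama,
    so this runs without PyTorch or Hugging Face.
    """

    def __init__(self, hidden_dims, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(hidden_dims)
        # Random query/key projections; a real model has learned weights.
        self.w_q = rng.standard_normal((hidden_dims, hidden_dims)) * scale
        self.w_k = rng.standard_normal((hidden_dims, hidden_dims)) * scale

    def _attention(self, embs):
        x = np.asarray(embs, dtype=np.float64)
        q, k = x @ self.w_q, x @ self.w_k
        scores = q @ k.T / np.sqrt(x.shape[-1])
        # Causal mask: token i attends only to tokens 0..i, as in a decoder-only LM.
        allowed = np.tril(np.ones(scores.shape, dtype=bool))
        scores = np.where(allowed, scores, -np.inf)
        # Row-wise softmax, subtracting the row max for numerical stability.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        return weights / weights.sum(axis=-1, keepdims=True)

    def predict(self, inputs):
        # One output dict per input example; field names are illustrative.
        return [
            {"attention": self._attention(ex["input_embs"]).astype(np.float32)}
            for ex in inputs
        ]


model = EmbeddingInputModel(hidden_dims=8)
example = {"input_embs": np.random.default_rng(1).standard_normal((5, 8))}
(out,) = model.predict([example])
print(out["attention"].shape)  # (5, 5)
```

Keeping the embedding matrix as a per-example field means the rest of the pipeline only ever sees plain NumPy arrays, which matches the advice above about converting tensors before they reach LIT.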

I'm interested in using ... [analyzing and visualizing] patterns of attention weights across different layers and create a salience map based on "tokens", where each "token" corresponds to a row of input embeddings.

LIT provides a Sequence Salience module that renders a salience map over tokenized text. Is that the kind of thing you're looking for, or do you want to display a salience map over a matrix of shape <float>(num_tokens, hidden_dims), or maybe even something else?

ltroin commented 4 months ago

Hi, sorry for the late reply.

Thank you so much for the detailed explanation.

We don't have a wrapper for Llama in an official release yet, but we're working on adding one in https://github.com/PAIR-code/lit/pull/1421.

This feature will be awesome! And yes, I also want to display a salience map over a matrix of shape (num_tokens, hidden_dims), where certain tokens (rows) are highlighted.
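One common way to get from a (num_tokens, hidden_dims) matrix to per-token highlights is gradient-norm salience: take the gradient of some model output with respect to each input embedding row, use that row's L2 norm as the token's score, and normalize the scores. The sketch below assumes the gradient matrix has already been computed (here it is a small hand-written array) and that "highlighted" means the top-k scoring rows; `token_salience` and `top_k_tokens` are hypothetical helpers, not LIT functions.

```python
import numpy as np


def token_salience(grads):
    """Per-token L2 norm of a (num_tokens, hidden_dims) gradient matrix, normalized to sum to 1."""
    grads = np.asarray(grads, dtype=np.float64)
    norms = np.linalg.norm(grads, axis=-1)  # one score per row/token
    total = norms.sum()
    return norms / total if total > 0 else norms


def top_k_tokens(scores, k):
    """Indices of the k highest-scoring tokens, highest first."""
    return np.argsort(scores)[::-1][:k]


# Row norms are 5, 1, 10, and 0, which sum to 16.
grads = np.array([[3.0, 4.0], [0.0, 1.0], [6.0, 8.0], [0.0, 0.0]])
scores = token_salience(grads)
print(scores.tolist())  # [0.3125, 0.0625, 0.625, 0.0]
print(top_k_tokens(scores, 2).tolist())  # [2, 0]
```

The per-token scores are one value per row, which is the same shape a token-level salience display consumes, so the full matrix can still be shown alongside with the top-k rows marked.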