PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0

Running LIT for a fine-tuned LLM #1418

Closed vincentabraham closed 6 months ago

vincentabraham commented 6 months ago

Hi, I'm a beginner with LLMs and I would like to use LIT to gain insights into how my fine-tuned Llama-13b-chat model makes predictions on binary classification problems. I found the examples given really confusing. Could anyone guide me on how to go about setting up LIT?

RyanMullins commented 6 months ago

Hi @vincentabraham! Happy to help work through how to set up LIT for your use case.

I have a few questions for you to help get started:

Your answers to the questions above will help me guide the process more specifically, but the general workflow will look something like the following. The key thing to keep in mind is that LIT is framework agnostic: it requires implementing wrappers around your models and datasets that convert data into a JSON-compatible structure that LIT's built-in interpreters, generators, and metrics know how to work with. For simplicity, let's assume we're working in a Colab so we don't have to worry about multiple files.

  1. Load the model into memory using your framework of choice.
    • Note that some models may not fit in memory and might require remote hosting for inference.
  2. Identify (or create) and instantiate the appropriate LIT model wrapper(s) for your model.
  3. Identify (or create) and instantiate the appropriate LIT dataset wrapper(s) for your binary classification dataset.
    • For fine-tuning it can be useful to instantiate LIT dataset wrappers for the training and test/eval/validation slices of your dataset.
  4. Collect the model(s) and dataset(s) into dictionaries and pass them to the LitWidget initializer to create a LIT server in Colab.
  5. Use the LitWidget.render() method to load the LIT UI in Colab for interactive analysis.
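
For steps 4 and 5, a minimal sketch of the Colab setup might look like the following. The wrapper class names and constructor arguments here are placeholders for whatever you end up implementing:

from lit_nlp import notebook

# Keys become the model and dataset names shown in the LIT UI.
models = {"llama_13b_chat": YourModelWrapper("path/to/finetuned-llama")}
datasets = {
    "train": YourDatasetWrapper("train.jsonl"),
    "eval": YourDatasetWrapper("eval.jsonl"),
}

# Starts a LIT server behind the scenes and renders the UI inline in Colab.
widget = notebook.LitWidget(models, datasets, height=800)
widget.render()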

Once we get this basic process working, we can look at different ways to inspect the model's performance, such as metrics (many of which are built into LIT) or input feature attribution methods. These decisions are driven by the affordances of the model-serving framework you're using; for example, llama.cpp does not provide easy access to gradients (AFAIK), whereas HF Transformers and KerasNLP do.

I recommend checking out our recent guides to model analysis, part of Google's Responsible Generative AI Toolkit, including a walkthrough of using LIT to aid in prompt engineering practices (also available as an interactive Codelab) and the accompanying Colab. These were written as examples of how to use LIT to analyze Gemma specifically, but they translate directly to Llama and other LLMs.

vincentabraham commented 6 months ago

Hi Ryan, thanks a lot for your response. Unfortunately, I can't share the code that was used for fine-tuning, but this is the article I referred to: https://medium.com/@geronimo7/finetuning-llama2-mistral-945f9c200611. I have the fine-tuned model locally and I could upload it to HuggingFace as well. I used the HF Transformers framework for fine-tuning. I've created some customized datasets for training, validation, and testing, which I'm using for fine-tuning and inference. What I want to do is gain insights into which portions of the user prompt are influencing its predictions.

RyanMullins commented 6 months ago

Thanks for the details.

I want to ... gain insights ... into which portions of the user prompt are influencing its predictions.

Based on this, you'll probably want to configure LIT to use Sequence Salience and classification metrics (the Python implementation that powers the metrics and confusion matrix modules in the UI). Sequence salience methods will let you dig into the influential portions of your prompt designs, and the classification metrics will help you slice your model's outputs (e.g., on your eval set) to find interesting subsets (e.g., false positive/false negative sets).

Getting this working will require some semi-specialized mechanics. Let's start with the good news/easy stuff...

Since you're using a dataset of your own creation, you'll need to define your own wrapper. I'm not sure what your fields are, but your dataset Spec should be roughly as follows:

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types

class YourDatasetWrapper(lit_dataset.Dataset):

  # ... the rest of the implementation ...

  def spec(self) -> lit_types.Spec:
    return {
        # The prompt text sent to the model.
        "prompt": lit_types.TextSegment(),
        # The reference (target) text for the generation.
        "target": lit_types.TextSegment(),
        # The binary label, as a string, for use with classification metrics.
        "target_cls": lit_types.CategoryLabel(vocab=["0", "1"]),
    }
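
The rest of the implementation mostly amounts to loading your data and populating self._examples with dicts whose keys match the spec. A rough sketch, assuming your splits are stored as JSONL files with "prompt", "target", and "label" fields (adjust to however your data is actually stored):

import json

class YourDatasetWrapper(lit_dataset.Dataset):

  def __init__(self, path: str):
    examples = []
    with open(path) as f:
      for line in f:
        record = json.loads(line)
        examples.append({
            "prompt": record["prompt"],
            "target": record["target"],
            # Stored as a string to match the CategoryLabel vocab above.
            "target_cls": str(record["label"]),
        })
    self._examples = examples

  # spec() as defined above.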

Our classification metrics require the presence of certain LitTypes in the model and dataset Specs, specifically a MulticlassPreds type. The existing wrappers that LIT provides for Causal LMs don't have these because they're designed for the general-case usage (i.e., TextSegment in, GeneratedText or GeneratedTextCandidates out), but adding them is really easy: you can define a subclass of the model wrapper, add a MulticlassPreds type to the output spec, and then map the generated text to a one-hot representation of your possible labels. Below is some reference code to get the model's output spec into the shape you want, though the predict function is slightly more complicated.

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types

class YourModelWrapper(SomeCausalLMWrapper):

  # ... the rest of the implementation ...

  def output_spec(self) -> lit_types.Spec:
    base_spec = super().output_spec()
    # Add a one-hot prediction field alongside the base generation outputs.
    return base_spec | {
        "pred_cls": lit_types.MulticlassPreds(parent="target_cls", vocab=["0", "1"])
    }
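
To illustrate the one-hot mapping, here's a rough sketch of what the prediction override might look like. It assumes the base wrapper yields the generated text in a "response" field; the actual field name depends on the wrapper you subclass, and depending on your LIT version the method to override may be predict_minibatch rather than predict:

class YourModelWrapper(SomeCausalLMWrapper):

  LABEL_VOCAB = ["0", "1"]

  # ... output_spec() as above ...

  def predict(self, inputs, **kw):
    for output in super().predict(inputs, **kw):
      generated = output.get("response", "").strip()
      # One-hot over the label vocab; all zeros if the model generated
      # something other than a known label.
      output["pred_cls"] = [
          1.0 if label == generated else 0.0 for label in self.LABEL_VOCAB
      ]
      yield output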

Okay now for the (slightly) bad news...

LIT doesn't have a built-in model wrapper for Llama yet, but it is something we're working on. If you want to wait for us to implement this, it might be a little while (ETA end of month-ish on the dev branch, probably longer before it's available in a release on PyPI). If you don't want to wait, you could try implementing your own wrapper based on LIT's GPT-2 wrapper. The HF interfaces for causal LMs are pretty standardized, so adapting it should be relatively straightforward.
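
If you do go the custom-wrapper route, a heavily stripped-down sketch of the shape it might take is below (the real GPT-2 wrapper does quite a bit more). The model path, generation settings, and the "response" field name are placeholders, and depending on your LIT version the batch method may be named predict_minibatch:

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
import transformers

class LlamaCausalLMWrapper(lit_model.Model):

  def __init__(self, model_name_or_path: str):
    self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
    self.model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path)

  def input_spec(self) -> lit_types.Spec:
    return {"prompt": lit_types.TextSegment()}

  def output_spec(self) -> lit_types.Spec:
    return {"response": lit_types.GeneratedText(parent="target")}

  def predict(self, inputs, **kw):
    for ex in inputs:
      encoded = self.tokenizer(ex["prompt"], return_tensors="pt")
      generated = self.model.generate(**encoded, max_new_tokens=16)
      # Decode only the newly generated tokens, not the echoed prompt.
      new_tokens = generated[0][encoded["input_ids"].shape[-1]:]
      yield {"response": self.tokenizer.decode(new_tokens, skip_special_tokens=True)}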

vincentabraham commented 6 months ago

Thank you for the help. I'll look into it.