stanfordnlp / dspy

DSPy: The framework for programming—not prompting—language models
https://dspy.ai
MIT License

Can DSPy be used with a discriminative NLI model like BERT? #1756

Open umarbutler opened 3 weeks ago

umarbutler commented 3 weeks ago

Suppose we have a BERT NLI model trained to zero-shot classify texts given a prompt such as 'This text relates to cars.' Is it possible to use DSPy to optimise the prompt for such a model?
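
For concreteness, a minimal sketch of the kind of setup I mean, using the Hugging Face zero-shot-classification pipeline (the model name and example text are just placeholders):

from transformers import pipeline

# A zero-shot NLI classifier whose output depends on the wording of the prompt.
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')
result = classifier('The new sedan gets 40 mpg on the highway.',
                    ['This text relates to cars.'], hypothesis_template='{}')
print(result['scores'][0])  # entailment-style confidence for this prompt wording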

okhat commented 3 weeks ago

Thanks @umarbutler! I initially misunderstood the question.

DSPy doesn't currently support BERT-like models (any non-autoregressive models) natively.

However, it's indeed possible and semi-straightforward to set up prompt optimization for BERT-style models. It's just going to require one level of indirection.

import dspy
from typing import Callable, Literal

NLI = Literal["entailment", "contradiction", "neutral"]

class PrompterNLI(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    prompt: str = dspy.OutputField(desc="prompt for a BERT model that conducts NLI")

class PromptedBERT(dspy.Module):
    def __init__(self, bert_nli_classifier: Callable[[str], NLI]):
        super().__init__()
        self.bert_nli_classifier = bert_nli_classifier
        self.prompter = dspy.ChainOfThought(PrompterNLI)

    def forward(self, premise: str, hypothesis: str) -> NLI:
        # Take the generated prompt string off the prediction object.
        prompt = self.prompter(premise=premise, hypothesis=hypothesis).prompt
        return self.bert_nli_classifier(prompt)

optimizer = dspy.MIPROv2(metric=METRIC, auto="medium", num_threads=THREADS)
optimized_bert = optimizer.compile(PromptedBERT(bert_nli_classifier), trainset=TRAINSET, requires_permission_to_run=False)

Note that to use this you'll need to specify METRIC, THREADS, and TRAINSET, and to build a wrapper bert_nli_classifier: Callable[[str], NLI] around your BERT model.
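
For concreteness, here's a rough sketch of what the wrapper and METRIC could look like. The model name is just a placeholder (any NLI cross-encoder works), and the metric assumes your trainset examples carry a label field:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "roberta-large-mnli"  # placeholder; swap in your own NLI model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

def bert_nli_classifier(prompt: str) -> NLI:
    # The generated prompt already embeds the premise and hypothesis,
    # so we pass it to the model as a single sequence.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(dim=-1).item()].lower()

def METRIC(example, pred, trace=None):
    # Exact match against the gold NLI label.
    return example.label == pred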

Now here's the catch: self.prompter currently depends on the premise and hypothesis, which means it'll be invoked afresh for every input. This could be expensive. Can this be fixed? Yes: you can instead build a prompter that's input-invariant.

I'll send that along shortly.

okhat commented 3 weeks ago
# Continues the previous snippet (same imports and NLI type).
class PrompterNLI(dspy.Signature):
    """
    Prepare a prompt with {premise} and {hypothesis} for an NLI BERT model.
    Use one [CLS] and two [SEP] tokens. Introduce creative prefixes or suffixes for the variables.
    """

    variables: list[str] = dspy.InputField()
    prompt: str = dspy.OutputField(desc="a pattern on which we may call .format()")

class PromptedBERT(dspy.Module):
    def __init__(self, bert_nli_classifier: Callable[[str], NLI]):
        super().__init__()
        self.bert_nli_classifier = bert_nli_classifier
        self.prompter = dspy.ChainOfThought(PrompterNLI, temperature=1.0)

    def forward(self, premise: str, hypothesis: str) -> NLI:
        # The prompter's inputs never change, so this call can be cached after optimization.
        prompt = self.prompter(variables=['premise', 'hypothesis']).prompt
        prompt = prompt.strip('"').format(premise=premise, hypothesis=hypothesis)
        return self.bert_nli_classifier(prompt)

Here, self.prompter always receives the same inputs, which means you can cache it after optimization.
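
For instance (sketching against the snippet above), after compile finishes you could freeze the prompt once and reuse the string directly:

# Freeze the optimized prompt once; every subsequent call reuses the cached string.
fixed_prompt = optimized_bert.prompter(variables=['premise', 'hypothesis']).prompt.strip('"')

def classify(premise: str, hypothesis: str) -> NLI:
    return bert_nli_classifier(fixed_prompt.format(premise=premise, hypothesis=hypothesis))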

Here's one sample output before optimization:

[CLS] Given the statement: {premise} [SEP] Assess the claim: {hypothesis} [SEP] Determine if the claim logically follows from the statement.

We've never done this before, though, so who knows how BERT will respond to these prompts!

umarbutler commented 2 weeks ago

Thank you, Omar. This was a great starting point for me.

It might be helpful for me to give you a bit more detail about what I'm trying to do.

In particular, I have a zero-shot NLI classifier (e.g., MoritzLaurer/deberta-v3-large-zeroshot-v2.0) that takes a piece of text (aka a premise) and a statement about that text (aka a hypothesis) and outputs a binary classification of whether or not the statement is supported by the text (aka entailment).

Sometimes, I might already have some labelled data which I can then use to optimise my statement.

For example, suppose I want to classify all of my blog posts by whether or not they relate to law. I could use a statement like 'This relates to law', but perhaps a better statement would be 'This blog post relates to law' or 'This blog post relates to the topic of law'. If I already have labelled data, I can simply try a bunch of different permutations of the same statement until I arrive at a statement that achieves the best Matthews correlation coefficient or accuracy.
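
Roughly, that manual search looks like this (pipe, texts, and labels stand in for my zero-shot pipeline and labelled data):

from sklearn.metrics import matthews_corrcoef

candidates = [
    'This relates to law.',
    'This blog post relates to law.',
    'This blog post relates to the topic of law.',
]

def score(statement: str, texts: list[str], labels: list[int]) -> float:
    # Binarise each pipeline score and compare against the gold labels.
    preds = [int(pipe(text, [statement], hypothesis_template='{}')['scores'][0] > 0.5) for text in texts]
    return matthews_corrcoef(labels, preds)

best_statement = max(candidates, key=lambda s: score(s, texts, labels))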

As you might imagine, this process can become quite involved if you need to repeat it for many different classification problems, models or hyperparameters.

I had a crack at adapting your code for this problem; however, it doesn't seem to work because my training examples are missing a statement input. Is it possible to have DSPy come up with statements on its own, without any examples of what an optimal statement looks like, solely by optimising for accuracy on the training set?

import dspy
import dspy.evaluate.metrics
import torch

from transformers import pipeline

class ZeroShotClassification(dspy.Signature):
    statement: str = dspy.InputField()
    text: str = dspy.InputField()
    supported: str = dspy.OutputField()

class ZeroShotClassifier(dspy.Module):
    def __init__(self, model_name: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
        super().__init__()
        self.pipe = pipeline('zero-shot-classification', model=model_name, device=device)
        self.prompter = dspy.ChainOfThought(ZeroShotClassification, temperature=1.0)

    def forward(self, text: str) -> str:
        statement = self.prompter(text=text).statement
        statement = statement.replace("'", '').replace('"', '').strip()

        # Wrap the statement in a list so the pipeline doesn't split it on commas;
        # the zero-shot pipeline returns {'labels': [...], 'scores': [...]}.
        score = self.pipe(text, [statement], hypothesis_template='{}')['scores'][0]
        return ['not supported', 'supported'][score > 0.5]

# BEGIN CONFIG #
ZERO_SHOT_CLASSIFIER_NAME = 'MoritzLaurer/deberta-v3-large-zeroshot-v2.0'
LM_NAME = 'openai/gpt-4o-mini'
THREADS = 1
AUTO_MODE = 'medium'
EXPERIMENTAL = True
NEEDS_PERMISSION = False
DEVICE = 'cpu'

SENTIMENT_EXAMPLES = [  # The goal would be to have DSPy come up with a prompt like 'This text expresses positive sentiment towards DSPy.'
    dspy.Example(
        text='I love DSPy!',
        supported='supported',
    ),
    dspy.Example(
        text='I hate DSPy!',
        supported='not supported',
    ),
    dspy.Example(
        text="I'm not sure how I feel about DSPy.",
        supported='not supported',
    ),
]
# END CONFIG #

SENTIMENT_EXAMPLES = [example.with_inputs('text') for example in SENTIMENT_EXAMPLES]

lm = dspy.LM(LM_NAME)
dspy.configure(lm=lm, experimental=EXPERIMENTAL)
zsc = ZeroShotClassifier(ZERO_SHOT_CLASSIFIER_NAME, DEVICE)
optimizer = dspy.MIPROv2(metric=dspy.evaluate.metrics.answer_exact_match, auto=AUTO_MODE, num_threads=THREADS)
optimized_zsc = optimizer.compile(zsc, trainset=SENTIMENT_EXAMPLES, requires_permission_to_run=NEEDS_PERMISSION)