stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

Data Extraction of multiple fields with suggestions from TypedPredictors #1060

Open BlueKiji77 opened 2 months ago

BlueKiji77 commented 2 months ago

I have been trying to extract data (title, questions answered, entities, summary) from document chunks.

I believed typed predictors would be good for this, but I keep running into a "Too many retries" error.

This left me wondering whether the way I defined the program is the problem: either I do not properly understand how Typed Predictors work, or I am trying to do too much in a single program.

WHAT THE PROGRAM SHOULD DO

Extract the following fields using DataExtractionSignature: Title, Summary, QuestionsAnswered, and Entities.

For each field in the Typed Predictor's output, the program assesses 3 to 5 properties using AssessDataSignature, and uses the Suggest module to flag each property that fails.

For output validation:


from pydantic import BaseModel, Field, ValidationError
from typing import List, Type, Any

from inspect import isclass

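def get_min_length(model: Type[BaseModel]) -> int:
    """Lower bound on the length of valid JSON for `model` (assuming all fields
    are required): the summed length of its field names, recursing into nested
    BaseModel fields."""
    min_length = 0
    for key, field in model.model_fields.items():
        min_length += len(key)
        # Generic annotations such as List[str] are not classes; skip them.
        if not isclass(field.annotation):
            continue
        if issubclass(field.annotation, BaseModel):
            min_length += get_min_length(field.annotation)
    return min_length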

# Data Extraction model and signature

class ExtractorInput(BaseModel):
    text: str = Field(description="The text from which data is to be extracted.")

class ExtractorOutput(BaseModel):
    title: str = Field(description="Appropriate title for the text.")
    summary: str = Field(description="Appropriate summary for the text.")
    entities: List[str] = Field(description="List of 3 to 5 key entities from the text.")
    questions_answered: List[str] = Field(description="List of 2 or 3 questions answered by the text.")

    @classmethod
    def model_validate_json(
        cls,
        json_data: str,
        *,
        strict: bool | None = None,
        context: dict[str, Any] | None = None
    ) -> "ExtractorOutput":
        # Validate the raw LM output; if it is wrapped in extra prose, fall back to
        # the longest substring that validates (brute force: O(n^2) attempts).
        __tracebackhide__ = True
        try:
            return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)
        except ValidationError as exc:
            # Bind the error now so the final raise works even if the loop below
            # never runs (i.e. json_data is shorter than min_length).
            last_exc = exc
        min_length = get_min_length(cls)
        for substring_length in range(len(json_data), min_length - 1, -1):
            for start in range(len(json_data) - substring_length + 1):
                substring = json_data[start:start + substring_length]
                try:
                    return cls.__pydantic_validator__.validate_json(substring, strict=strict, context=context)
                except ValidationError as exc:
                    last_exc = exc
        raise ValueError("Could not find valid json") from last_exc
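The brute-force fallback in `model_validate_json` re-validates every substring, which is O(n²) pydantic calls for an n-character response. A lighter alternative, assuming the LM's JSON is a single object surrounded by prose, is to try each `{` in turn with the standard library's `json.JSONDecoder.raw_decode`, which parses one JSON value starting at a given index and ignores whatever follows. `first_json_object` is a hypothetical helper, not part of DSPy or pydantic; its result would still need to go through pydantic (for example `ExtractorOutput.model_validate(...)` in pydantic v2).

```python
import json
from typing import Any, Optional


def first_json_object(text: str) -> Optional[dict[str, Any]]:
    """Return the first JSON object that parses inside `text`, or None.

    Hypothetical helper (not part of DSPy or pydantic). Caveat: if an outer
    object is malformed, a nested inner object may be returned instead.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and reports
            # where it ended, so trailing prose is simply ignored.
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next one.
            start = text.find("{", start + 1)
    return None
```

This makes at most one parse attempt per `{` in the text, instead of one per substring.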

from dspy import Signature, InputField, OutputField

class DataExtractionSignature(Signature):
    input: ExtractorInput = InputField()
    output: ExtractorOutput = OutputField(desc="Metadata from input in JSON Format.")
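The entity and question counts in ExtractorOutput (3 to 5 entities, 2 or 3 questions) live only in field descriptions, so nothing enforces them. A minimal sketch of checking those constraints deterministically instead of asking an LM judge, assuming the extracted data is available as a plain dict (for a pydantic v2 model, `model_dump()` returns one). `list_length_violations` is a hypothetical helper, not part of DSPy.

```python
from typing import Any


def list_length_violations(output: dict[str, Any]) -> list[str]:
    """Hypothetical helper: check the count hints from ExtractorOutput.

    Returns one human-readable message per violated constraint. The ranges
    are copied from the field descriptions (3-5 entities, 2-3 questions).
    """
    rules = {
        "entities": (3, 5),
        "questions_answered": (2, 3),
    }
    messages = []
    for field_name, (low, high) in rules.items():
        n = len(output.get(field_name, []))
        if not low <= n <= high:
            messages.append(f"{field_name} should contain {low} to {high} items, got {n}.")
    # Free-text fields only need to be non-blank.
    for field_name in ("title", "summary"):
        if not str(output.get(field_name, "")).strip():
            messages.append(f"{field_name} should not be empty.")
    return messages
```

Each message could serve as the feedback string of a Suggest call, the same way the criterion suggestions are used in `forward()`, without spending LM calls on properties plain code can verify.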

Assessment model and signature


from dspy import Signature, InputField, OutputField, Suggest
from pydantic import BaseModel, Field
from typing import List, Dict, Union

def is_assessment_yes(assessment_answer: str) -> bool:
    """Check if the first word of the assessment answer is 'yes' (False for an empty answer)."""
    words = assessment_answer.split()
    return bool(words) and words[0].lower() == 'yes'

class AssessmentInput(BaseModel):
    assessed_text: str = Field()
    assessment_criterias: List[str] = Field(description="Criteria the assessment object will be evaluated on")
    assessment_criteria_questions: List[str] = Field(description="Assessment questions by criteria")
    assessment_object: Union[str, List[str]] = Field(description="Object of the assessment")

class AssessmentOutput(BaseModel):
    assessment_output: Dict[str, bool] = Field(description="Key-value pairs {criterion name from assessment_criterias: whether the assessment was passed}")

class AssessDataSignature(Signature):
    """Assess some particular information extracted or derived from a text"""
    input: AssessmentInput = InputField()
    output: AssessmentOutput = OutputField(desc="assessment in JSON format")
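AssessmentOutput is a free-form `Dict[str, bool]`, so nothing guarantees that the LM echoes the names from `assessment_criterias` exactly, or in the same order. A minimal sketch of normalizing the result against the expected names before scoring or issuing suggestions; `align_assessment` is a hypothetical helper, and treating an omitted criterion as failed is an assumption you may want to change.

```python
def align_assessment(raw: dict[str, bool], criteria_names: list[str]) -> dict[str, bool]:
    """Hypothetical helper: re-key an LM-produced {criterion: passed} dict by the
    expected criterion names, independent of the order the LM used.

    Keys are matched case-insensitively with surrounding whitespace stripped;
    a criterion the LM omitted is treated as failed (False).
    """
    normalized = {key.strip().lower(): value for key, value in raw.items()}
    return {name: normalized.get(name.strip().lower(), False) for name in criteria_names}
```

The result's key order follows `criteria_names`, so it can be zipped safely with the matching list of suggestions.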

Metrics for Evaluation and Assessment

### Summary Assessment
def summary_metrics(example, pred, return_details=False):
    text = pred.text 
    summary = pred.extracted_data.summary

    comprehensiveness = """does the summary include all the key points and main ideas from the original text, without omitting any critical information? Answer by True or False"""

    coherence = """does the summary flow logically and make sense as a standalone piece of writing? Answer by True or False"""

    independence = """does the summary maintain an objective, unbiased tone and use the summarizer's own words rather than directly quoting the original text? Answer by True or False"""

    accuracy = """does the summary precisely capture the meaning and intent of the original text, without any distortions or misrepresentations? Answer by True or False"""

    relevance = """does the summary focus on the key points that are most relevant to getting a complete overview of what is addressed in the text? Answer by True or False"""

    clarity = """is the summary easy to understand and free of ambiguity? Answer by True or False"""

    criterias = [comprehensiveness, coherence, independence, accuracy, relevance, clarity]
    criterias_names = ['comprehensiveness', 'coherence', 'independence', 'accuracy', 'relevance', 'clarity']
    criterias_suggestions = [
                            "It should provide a complete overview of the text's content.",
                            "It should use transitions effectively and have a clear structure.",
                            "It should not introduce any new ideas or criticisms.",
                            "It should not include any factual errors.",
                            "It should only include subjects included in the text",
                            "It should use clear, concise language and avoid jargon or complex sentence structures"]

    assessment_input = AssessmentInput(assessed_text=text,
                                      assessment_criterias=criterias_names,
                                      assessment_criteria_questions=criterias,
                                      assessment_object=summary)
    summary_assessment = TypedPredictor(AssessDataSignature)(input=assessment_input)

    if return_details:
        return summary_assessment, criterias_names, criterias_suggestions
    # Score = fraction of criteria judged as passed (True counts as 1, False as 0).
    values = [int(value) for value in summary_assessment.output.assessment_output.values()]
    print(f"values: {values}")
    # Guard against an empty assessment dict, which would otherwise raise ZeroDivisionError.
    return sum(values) / len(values) if values else 0.0

### Entities Assessment
def entities_metrics(example, pred, return_details=False):
    text = pred.text 
    entities = pred.extracted_data.entities

    accuracy = """do the entities accurately represent the concepts and objects mentioned in the original text, without any distortions or misrepresentations? Answer by True or False"""

    consistency = """do the entities use a standardized set of entity types and follow a consistent naming convention? Answer by True or False"""

    relevance = """are the entities relevant to the context and to the purpose of capturing semantic information for information retrieval? Answer by True or False"""

    clarity = """are the entities easy to understand and free of ambiguity? Answer by True or False"""

    # Similar to summary_metrics above
...
...
...
...
...
...
    return score

### Questions Answered Assessment
def qas_metrics(example, pred, return_details=False):
    text = pred.text 
    questions_answered = pred.extracted_data.questions_answered

    complexity = """does each question need more than a yes-or-no answer to be properly answered? Answer by True or False"""

    specificity = """is each question focused on a single issue or concept in the text? Answer by True or False"""

    appropriateness = """is each question tailored to the specific text and its context, rather than a generic question that could apply to any text? Answer by True or False"""

    # Similar to summary_metrics above
...
...
...
...
...
...
    return score

### Title Assessment
def title_metrics(example, pred, return_details=False):
    text = pred.text 
    title = pred.extracted_data.title

    clarity = """is the title clear, concise, and easy to understand? Answer by True or False"""

    relevance = """is the title relevant to the main topic? Answer by True or False"""

    alignment = """does the title accurately reflect the content and scope of the text? Answer by True or False"""

    # Similar to summary_metrics above
...
...
...
...
...
...
    return score

def overall_metric(example, pred):
    print("==="*50)
    print(f"Example: \ntext:{pred.text} \nextracted_data: {pred.extracted_data} ")

    print('\nTitle Evaluation')
    title_score = title_metrics(example, pred)
    print(f"Score: {title_score}")

    print("\nQAS Evaluation")
    qas_score = qas_metrics(example, pred)
    print(f"Score: {qas_score}")

    print("\nEntities Evaluation")
    entities_score = entities_metrics(example, pred)
    print(f"Score: {entities_score}")

    print("\nSummary Evaluation")
    summary_score = summary_metrics(example, pred)
    print(f"Score: {summary_score}")

    score = (title_score + qas_score + entities_score + summary_score) / 4
    print(f"Overall Score: {score}")
    return score

The program itself


import json
from dspy import Signature, InputField, OutputField, Suggest, TypedChainOfThought, TypedPredictor, Module, Prediction    
import gc 
import torch

class DataExtractorPlusAssertion(Module):
    def __init__(self, max_retries=3):
        super().__init__()
        self.extractor = TypedPredictor(signature=DataExtractionSignature, max_retries=max_retries)

    def forward(self, text):
        extractor_input = ExtractorInput(text=text)
        pred = self.extractor(input=extractor_input)
        pred = Prediction(text=text, extracted_data=pred.output)

        #################### METRICS  ASSESSMENT BEGIN ####################
        ### Summary metric
        summary_evals, summary_evals_names, summary_evals_suggestions = summary_metrics(None, pred, return_details=True)
        summary_eval_output = dict(summary_evals.output.assessment_output)

        ### Title metric
        title_evals, title_evals_names, title_evals_suggestion = title_metrics(None, pred, return_details=True)
        title_eval_output = dict(title_evals.output.assessment_output)

        ### QAs metric
        qas_evals, qas_evals_names, qas_evals_suggestion = qas_metrics(None, pred, return_details=True)
        qas_eval_output = dict(qas_evals.output.assessment_output)

        # Entities metric
        entities_evals, entities_evals_names, entities_evals_suggestion = entities_metrics(None, pred, return_details=True)
        entities_eval_output = dict(entities_evals.output.assessment_output)

        #################### dspy.Suggest ####################
        # Pair each criterion with its suggestion and look its result up by name,
        # instead of zipping dict items positionally (which depends on the LM's key order).
        metric_result_triplets = [(summary_eval_output, summary_evals_names, summary_evals_suggestions),
                                  (title_eval_output, title_evals_names, title_evals_suggestion),
                                  (qas_eval_output, qas_evals_names, qas_evals_suggestion),
                                  (entities_eval_output, entities_evals_names, entities_evals_suggestion)]

        for eval_assessments, eval_names, eval_suggestions in metric_result_triplets:
            for eval_name, eval_suggestion in zip(eval_names, eval_suggestions):
                # A criterion missing from the LM's output is treated as failed.
                passed = bool(eval_assessments.get(eval_name, False))
                Suggest(passed, eval_suggestion, target_module=DataExtractionSignature)

        return pred
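The Suggest calls in `forward()` are meant to push the extractor to try again, with feedback, when a criterion fails. The sketch below illustrates that general idea in plain Python under stated assumptions: `generate` stands in for the extractor and `assess` for the assessment step, both hypothetical callables. It is not DSPy's backtracking implementation, only an illustration of why each failed criterion can cost another LM call.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def refine_with_feedback(
    generate: Callable[[Optional[str]], T],
    assess: Callable[[T], dict[str, bool]],
    suggestions: dict[str, str],
    max_backtracks: int = 2,
) -> T:
    """Framework-agnostic sketch of a suggestion loop (not DSPy's code).

    generate(feedback) produces a candidate; feedback is None on the first
    call and afterwards the joined suggestions for every failed criterion.
    Like a soft suggestion, running out of attempts returns the last
    candidate instead of raising.
    """
    feedback = None
    candidate = generate(feedback)
    for _ in range(max_backtracks):
        results = assess(candidate)
        failed = [name for name, passed in results.items() if not passed]
        if not failed:
            break
        # Feed the suggestions for the failed criteria back into the next attempt.
        feedback = " ".join(suggestions[name] for name in failed if name in suggestions)
        candidate = generate(feedback)
    return candidate
```

In this sketch the worst case is `1 + max_backtracks` generations per input, on top of the LM calls made by the assessment step itself, which is worth keeping in mind with four assessed fields per chunk.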

Also, optimize_signature seems really obscure to me. Judging from my LLM history, it does not seem to optimize the prompts, and it throws the "Too many retries" error after about 4 iterations. Can someone point me in the right direction?

arnavsinghvi11 commented 1 month ago

Hi @BlueKiji77, I would recommend trying the standard ChainOfThought and seeing whether that resolves some of the issues here. Part of the reason you run into "Too many retries" is that the model itself is not adhering to the requirements (#957), not that your pipeline is too complex.
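To make the reply's point concrete: a retry budget only helps if the model actually fixes its output after seeing the error. A framework-agnostic sketch of a parse-and-retry loop (hypothetical names; not DSPy's TypedPredictor code) shows how a model that keeps ignoring the required format exhausts `max_retries` and ends in a "Too many retries" error, however simple the surrounding pipeline is.

```python
import json
from typing import Any, Callable


def parse_with_retries(
    generate: Callable[[list[str]], str],
    parse: Callable[[str], Any] = json.loads,
    max_retries: int = 3,
) -> Any:
    """Hypothetical parse-and-retry loop (not DSPy's implementation).

    generate(errors) receives every parse error seen so far, so a capable
    model can correct its format; parse failures must raise ValueError
    (json.JSONDecodeError is a subclass).
    """
    errors: list[str] = []
    for _ in range(max_retries):
        raw = generate(errors)
        try:
            return parse(raw)
        except ValueError as exc:
            # Record the error so the next attempt can see what went wrong.
            errors.append(str(exc))
    raise ValueError(f"Too many retries ({max_retries}): {errors}")
```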