Open felixgao opened 5 months ago
Adding a fuzzy model_validate_json to your model could help - https://github.com/stanfordnlp/dspy/issues/1001#issuecomment-2110984085
from inspect import isclass
from typing import Any, Type

from pydantic import BaseModel, Field, ValidationError


def get_min_length(model: Type[BaseModel]) -> int:
    # Lower bound on the length of a JSON string that could validate:
    # the summed length of all field names, recursing into nested models.
    min_length = 0
    for key, field in model.model_fields.items():
        min_length += len(key)
        if not isclass(field.annotation):
            continue
        if issubclass(field.annotation, BaseModel):
            min_length += get_min_length(field.annotation)
    return min_length


class ModelField(BaseModel):
    key: str = Field(description="The key to be extracted from the context")
    value: str | None = Field(description="the value to be extracted from the context")
    truthful: bool = Field(
        description="is a tuple contains the extracted value and a boolean indicating if the information is truthful and correct."
    )
    confidence: float = Field(ge=0, le=1, description="The confidence score for the answer")


class Fields(BaseModel):
    fields: list[ModelField]

    @classmethod
    def model_validate_json(
        cls,
        json_data: str,
        *,
        strict: bool | None = None,
        context: dict[str, Any] | None = None
    ) -> "Fields":
        try:
            __tracebackhide__ = True
            return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)
        except ValidationError:
            # Fall back to trying every substring, longest first.
            min_length = get_min_length(cls)
            last_exc = None  # stays None if the loops below never run
            for substring_length in range(len(json_data), min_length - 1, -1):
                for start in range(len(json_data) - substring_length + 1):
                    substring = json_data[start:start + substring_length]
                    try:
                        __tracebackhide__ = True
                        res = cls.__pydantic_validator__.validate_json(substring, strict=strict, context=context)
                        return res
                    except ValidationError as exc:
                        last_exc = exc
            raise ValueError("Could not find valid json") from last_exc


output = "```json\n{\"fields\": [{\"key\": \"Social Security Number\", \"value\": \"XXX-XX-2489\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Employer's FED ID number\", \"value\": \"85-4006070\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Employee's SSA number\", \"value\": \"XXX-XX-2489\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Wages, tips, other comp.\", \"value\": \"2307.11\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Federal income tax withheld\", \"value\": \"41.98\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Social security wages\", \"value\": \"2129.03\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Social security tax withheld\", \"value\": \"143.04\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Medicare wages and tips\", \"value\": \"2307.11\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Medicare tax withheld\", \"value\": \"33.45\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"State\", \"value\": \"AL\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Employer's state ID no.\", \"value\": \"R010981778\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"State wages, tips, etc.\", \"value\": \"2307.11\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"State income tax\", \"value\": \"87.86\", \"truthful\": true, \"confidence\": 1.0}, {\"key\": \"Local income tax\", \"value\": null, \"truthful\": true, \"confidence\": 0.0}]}\n```"
Fields.model_validate_json(output)
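The substring scan above tries every window of the output, so the number of validation attempts grows quadratically with the length of the response. A cheaper variant, sketched here with only the standard library, slices from the first `{` to the last `}` before parsing. The helper name `extract_json_object` is hypothetical, and the sketch assumes the response contains a single top-level JSON object:

```python
import json

FENCE = "`" * 3  # three backticks, as used around fenced model output


def extract_json_object(text: str) -> dict:
    """Parse the outermost {...} span found in text (hypothetical helper)."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("Could not find a JSON object")
    # json.loads raises json.JSONDecodeError (a ValueError subclass) on bad JSON.
    return json.loads(text[start : end + 1])


fenced = FENCE + 'json\n{"fields": [{"key": "State", "value": "AL"}]}\n' + FENCE
parsed = extract_json_object(fenced)
print(parsed["fields"][0]["value"])
```

The resulting dict could then be handed to `Fields.model_validate(...)` so the usual pydantic checks still apply.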
Thank you Mike. I think from the pydantic side it works. However, after making the changes to the signatures, my code no longer works with DSPy.
/Users/ggao/github/ged/GED_MONOREPO/projects/gemini/.venv/lib/python3.11/site-packages/pydantic/_internal/_fields.py:200: UserWarning: Field name "fields" in "QAExtractionSignature" shadows an attribute in parent "Signature"
  warnings.warn(
Traceback (most recent call last):
  File "/Users/ggao/github/ged/GED_MONOREPO/projects/gemini/gemini/typed_doc_qa.py", line 58, in <module>
    class QAExtractionSignature(Signature):
  File "/Users/ggao/github/ged/GED_MONOREPO/projects/gemini/.venv/lib/python3.11/site-packages/dspy/signatures/signature.py", line 63, in __new__
    cls._validate_fields()
  File "/Users/ggao/github/ged/GED_MONOREPO/projects/gemini/.venv/lib/python3.11/site-packages/dspy/signatures/signature.py", line 79, in _validate_fields
    raise TypeError(
TypeError: Field 'fields' in 'QAExtractionSignature' must be declared with InputField or OutputField. field.json_schema_extra=None
What changes were made to the signature?
no changes.
class QAExtractionSignature(Signature):
    """Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.\nplease keep the content you extract truthful and correct based on the document text provided"""

    document: str = InputField()
    instruction: str = InputField(
        description="The instruction on which field to extact and how to extract it."
    )
    fields: Fields = OutputField(
        description="The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!"
    )
I suspect you're interfering with some internal namespace: dspy.SignatureMeta has a fields property, which I'm guessing is clashing with the fields attribute you're defining. In general it's good to avoid generic names for members if you can.
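For illustration, here is a generic Python sketch of how a property on a metaclass can shadow a class attribute of the same name. The classes `Meta` and `Demo` are made up for the example and are not DSPy's actual code, so whether DSPy fails in exactly this way is an assumption:

```python
class Meta(type):
    @property
    def fields(cls):
        # Stand-in for a framework-level property defined on the metaclass.
        return {"source": "metaclass property"}


class Demo(metaclass=Meta):
    # User-defined attribute that reuses the same name.
    fields = "user-defined value"


# Class-level lookup finds the metaclass property first, because a property
# is a data descriptor on type(Demo).
print(Demo.fields)
# The user's value is still stored on the class, just hidden from Demo.fields.
print(Demo.__dict__["fields"])
```

Renaming the attribute to something less generic avoids the clash altogether.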
Upon further investigation, it looks like the model_validate_json deserializer is wrapped by the function below:
def _unwrap_json(output):
    output = output.strip()
    if output.startswith("```"):
        if not output.startswith("```json"):
            raise ValueError("json output should start with ```json")
        if not output.endswith("```"):
            raise ValueError("Don't write anything after the final json ```")
        output = output[7:-3].strip()
    if not output.startswith("{") or not output.endswith("}"):
        raise ValueError("json output should start and end with { and }")
    return ujson.dumps(ujson.loads(output))  # ujson is a bit more robust than the standard json
Users currently cannot replace this wrapper with their own deserializer. I will open a PR to make it optional.
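To see what this wrapper accepts and rejects, here is a stand-alone replica of the checks quoted above. It swaps `ujson` for the standard-library `json` module (an assumption, made so the sketch runs without extra dependencies), and the name `unwrap_json_sketch` is hypothetical:

```python
import json

FENCE = "`" * 3  # three backticks


def unwrap_json_sketch(output: str) -> str:
    """Replica of the quoted checks, using json instead of ujson."""
    output = output.strip()
    if output.startswith(FENCE):
        if not output.startswith(FENCE + "json"):
            raise ValueError("json output should start with a json fence")
        if not output.endswith(FENCE):
            raise ValueError("Don't write anything after the final json fence")
        # Drop the 7-character opening fence and the 3-character closing fence.
        output = output[7:-3].strip()
    if not output.startswith("{") or not output.endswith("}"):
        raise ValueError("json output should start and end with { and }")
    return json.dumps(json.loads(output))


fenced = FENCE + 'json\n{"a": 1}\n' + FENCE
print(unwrap_json_sketch(fenced))

try:
    # Any text after the closing fence is rejected before parsing starts.
    unwrap_json_sketch(fenced + "\nHope this helps!")
except ValueError as exc:
    print("rejected:", exc)
```

Because these checks run before the model's own `model_validate_json`, a fuzzy override on the model never sees output that the wrapper has already rejected.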
Thanks for investigating further, Mike. I am still confused about what is going on; the way typed signatures work seems very brittle. I made the following change and I am back at the "Too many retries" problem:
raise ValueError(
ValueError: ('Too many retries trying to get the correct output format. Try simplifying the requirements.', {'tax_from_fields': "ValueError('Could not find valid json')"})
# Imports assumed from the names used below.
from inspect import isclass
from typing import Any, Type

import dspy
from dspy import InputField, OutputField, Signature
from pydantic import BaseModel, Field, ValidationError


def get_min_length(model: Type[BaseModel]) -> int:
    min_length = 0
    for key, field in model.model_fields.items():
        min_length += len(key)
        if not isclass(field.annotation):
            continue
        if issubclass(field.annotation, BaseModel):
            min_length += get_min_length(field.annotation)
    return min_length


class ModelField(BaseModel):
    key: str = Field()
    value: str = Field()
    truthful: bool = Field(
        description="is a tuple contains the extracted value and a boolean indicating if the inFieldsation is truthful and correct."
    )
    confidence: float = Field(ge=0, le=1, description="The confidence score for the answer")


class TaxFromFields(BaseModel):
    tax_from: list[ModelField] = Field()

    @classmethod
    def model_validate_json(
        cls, json_data: str, *, strict: bool | None = None, context: dict[str, Any] | None = None
    ) -> "TaxFromFields":
        try:
            __tracebackhide__ = True
            return cls.__pydantic_validator__.validate_json(
                json_data, strict=strict, context=context
            )
        except ValidationError:
            # Fall back to trying every substring, longest first.
            min_length = get_min_length(cls)
            last_exc = None  # stays None if the loops below never run
            for substring_length in range(len(json_data), min_length - 1, -1):
                for start in range(len(json_data) - substring_length + 1):
                    substring = json_data[start : start + substring_length]
                    try:
                        __tracebackhide__ = True
                        res = cls.__pydantic_validator__.validate_json(
                            substring, strict=strict, context=context
                        )
                        return res
                    except ValidationError as exc:
                        last_exc = exc
            raise ValueError("Could not find valid json") from last_exc


class QAExtractionSignature(Signature):
    """Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.\nplease keep the content you extract truthful and correct based on the document text provided"""

    document: str = InputField()
    instruction: str = InputField(
        description="The instruction on which field to extact and how to extract it."
    )
    tax_from_fields: TaxFromFields = OutputField(
        description="The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!"
    )


class TypedQAExtraction(dspy.Module):
    def __init__(self):
        super().__init__()
        self.qa_extraction = dspy.functional.TypedPredictor(QAExtractionSignature)

    def forward(self, document: str, instruction: str) -> TaxFromFields:
        return self.qa_extraction(document=document, instruction=instruction).tax_from_fields
Looking at the trace, it seems that Predict(QAExtractionSignature).forward is sending in:
{"document": "1 Wages, tips, other comp. 2307.11 2 Federal income tax withheld 41.98 3 Social security wages 2129.03 4 Social security tax withheld 143.04 5 Medicare wages and tips 2307.11 6 Medicare tax withheld 33.45 d Control number 0000034180 TM6 Dept. Corp. CTQF Employer use only 55394 c Employer's name, address, and ZIP code HUT AMERICAN GROUP LLC 6200 OAK TREE BLVD - SUITE 250 INDEPENDENCE, OH 44131 b Employer's FED ID number 85-4006070 a Employee's SSA number XXX-XX-2489 7 Social security tips 178.08 8 Allocated tips 9 10 Dependent care benefits 11 Nonqualified plans 12a DD | 69.71 14 Other 12b 12c 12d 13 Stat emp. Ret. plan 3rd party sick pay e/f Employee's name, address and ZIP code MARY A DAWKINS 1609 DEATSVILLE HWY MILLBROOK, AL 36054 15 State AL Employer's state ID no. R010981778 16 State wages, tips, etc. 2307.11 17 State income tax 87.86 18 Local wages, tips, etc. 19 Local income tax 20 Locality name City or Local Filing Copy W-2 Wage and Tax 2021 Statement OMB No. 1545-0008 Copy 2 to be filed with employee's City or Local Income Tax Return.", "instruction": "extract the following fields (Semicolon separated): Social Security Number;Employer's FED ID number;Employee's SSA number;Wages, tips, other comp.;Federal income tax withheld;Social security wages;Social security tax withheld;Medicare wages and tips;Medicare tax withheld;State;Employer's state ID no.;State wages, tips, etc.;State income tax;Local income tax", "error_tax_from_fields_0": "ValueError('Could not find valid json')", "new_signature": "StringSignature(document, instruction, error_tax_from_fields_0 -> tax_from_fields\n instructions='Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.\\nplease keep the content you extract truthful and correct based on the document text provided'\n document = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Document:', 'desc': '${document}', 
'format': <function TypedPredictor._prepare_signature.<locals>.<lambda> at 0x15443ad40>})\n instruction = Field(annotation=str required=True description='The instruction on which field to extact and how to extract it.' json_schema_extra={'__dspy_field_type': 'input', 'desc': 'The instruction on which field to extact and how to extract it.', 'prefix': 'Instruction:', 'format': <function TypedPredictor._prepare_signature.<locals>.<lambda> at 0x1544cdc60>})\n error_tax_from_fields_0 = Field(annotation=str required=True json_schema_extra={'prefix': 'Past Error in Tax From Fields:', 'desc': 'An error to avoid in the future', '__dspy_field_type': 'input'})\n tax_from_fields = Field(annotation=TaxFromFields required=True description='The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!' json_schema_extra={'__dspy_field_type': 'output', 'desc': 'The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!. Respond with a single JSON object. 
JSON Schema: {\"$defs\": {\"ModelField\": {\"properties\": {\"key\": {\"title\": \"Key\", \"type\": \"string\"}, \"value\": {\"title\": \"Value\", \"type\": \"string\"}, \"truthful\": {\"description\": \"is a tuple contains the extracted value and a boolean indicating if the inFieldsation is truthful and correct.\", \"title\": \"Truthful\", \"type\": \"boolean\"}, \"confidence\": {\"description\": \"The confidence score for the answer\", \"maximum\": 1.0, \"minimum\": 0.0, \"title\": \"Confidence\", \"type\": \"number\"}}, \"required\": [\"key\", \"value\", \"truthful\", \"confidence\"], \"title\": \"ModelField\", \"type\": \"object\"}}, \"properties\": {\"tax_from\": {\"items\": {\"$ref\": \"#/$defs/ModelField\"}, \"title\": \"Tax From\", \"type\": \"array\"}}, \"required\": [\"tax_from\"], \"title\": \"TaxFromFields\", \"type\": \"object\"}', 'prefix': 'Tax From Fields:', 'format': <function TypedPredictor._prepare_signature.<locals>.<lambda> at 0x1544ce3e0>, 'parser': <function TypedPredictor._prepare_signature.<locals>.<lambda> at 0x1544ce5c0>})\n)"}
The output is the following, which seems fine to me:
{"tax_from_fields": "```json\n{\n \"tax_from\": [\n {\n \"key\": \"Social Security Number\",\n \"value\": \"XXX-XX-2489\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Employer's FED ID number\",\n \"value\": \"85-4006070\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Employee's SSA number\",\n \"value\": \"XXX-XX-2489\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Wages, tips, other comp.\",\n \"value\": \"2307.11\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Federal income tax withheld\",\n \"value\": \"41.98\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Social security wages\",\n \"value\": \"2129.03\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Social security tax withheld\",\n \"value\": \"143.04\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Medicare wages and tips\",\n \"value\": \"2307.11\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Medicare tax withheld\",\n \"value\": \"33.45\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"State\",\n \"value\": \"AL\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Employer's state ID no.\",\n \"value\": \"R010981778\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"State wages, tips, etc.\",\n \"value\": \"2307.11\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"State income tax\",\n \"value\": \"87.86\",\n \"truthful\": true,\n \"confidence\": 1.0\n },\n {\n \"key\": \"Local income tax\",\n \"value\": null,\n \"truthful\": false,\n \"confidence\": 0.0\n }\n ]\n}\n```"}
However, when I look at what the model actually received as input, some of the failed forward calls seem to have some kind of syntax problem in the prompt:
Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.
please keep the content you extract truthful and correct based on the document text provided
---
Follow the following format.
Document: ${document}
Instruction: The instruction on which field to extact and how to extract it.
Past Error in Tax From Fields: An error to avoid in the future
Tax From Fields: The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!. Respond with a single JSON object. JSON Schema: {"$defs": {"ModelField": {"properties": {"key": {"title": "Key", "type": "string"}, "value": {"title": "Value", "type": "string"}, "truthful": {"description": "is a tuple contains the extracted value and a boolean indicating if the inFieldsation is truthful and correct.", "title": "Truthful", "type": "boolean"}, "confidence": {"description": "The confidence score for the answer", "maximum": 1.0, "minimum": 0.0, "title": "Confidence", "type": "number"}}, "required": ["key", "value", "truthful", "confidence"], "title": "ModelField", "type": "object"}}, "properties": {"tax_from": {"items": {"$ref": "#/$defs/ModelField"}, "title": "Tax From", "type": "array"}}, "required": ["tax_from"], "title": "TaxFromFields", "type": "object"}
---
Document: 1 Wages, tips, other comp. 2307.11 2 Federal income tax withheld 41.98 3 Social security wages 2129.03 4 Social security tax withheld 143.04 5 Medicare wages and tips 2307.11 6 Medicare tax withheld 33.45 d Control number 0000034180 TM6 Dept. Corp. CTQF Employer use only 55394 c Employer's name, address, and ZIP code HUT AMERICAN GROUP LLC 6200 OAK TREE BLVD - SUITE 250 INDEPENDENCE, OH 44131 b Employer's FED ID number 85-4006070 a Employee's SSA number XXX-XX-2489 7 Social security tips 178.08 8 Allocated tips 9 10 Dependent care benefits 11 Nonqualified plans 12a DD | 69.71 14 Other 12b 12c 12d 13 Stat emp. Ret. plan 3rd party sick pay e/f Employee's name, address and ZIP code MARY A DAWKINS 1609 DEATSVILLE HWY MILLBROOK, AL 36054 15 State AL Employer's state ID no. R010981778 16 State wages, tips, etc. 2307.11 17 State income tax 87.86 18 Local wages, tips, etc. 19 Local income tax 20 Locality name City or Local Filing Copy W-2 Wage and Tax 2021 Statement OMB No. 1545-0008 Copy 2 to be filed with employee's City or Local Income Tax Return.
Instruction: extract the following fields (Semicolon separated): Social Security Number;Employer's FED ID number;Employee's SSA number;Wages, tips, other comp.;Federal income tax withheld;Social security wages;Social security tax withheld;Medicare wages and tips;Medicare tax withheld;State;Employer's state ID no.;State wages, tips, etc.;State income tax;Local income tax
Past Error in Tax From Fields: ValueError('Could not find valid json')
Tax From Fields:
Another failed call received:
Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.
please keep the content you extract truthful and correct based on the document text provided
---
Follow the following format.
Document: ${document}
Instruction: The instruction on which field to extact and how to extract it.
Past Error in Tax From Fields: An error to avoid in the future
Past Error (2) in Tax From Fields: An error to avoid in the future
Tax From Fields:
The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!. Respond with a single JSON object.
You MUST use this format: ```json
{"tax_from": [{"key": "tax_from", "value": "2022-10-10", "truthful": true, "confidence": 0.9}]}
JSON Schema: {"$defs": {"ModelField": {"properties": {"key": {"title": "Key", "type": "string"}, "value": {"title": "Value", "type": "string"}, "truthful": {"description": "is a tuple contains the extracted value and a boolean indicating if the inFieldsation is truthful and correct.", "title": "Truthful", "type": "boolean"}, "confidence": {"description": "The confidence score for the answer", "maximum": 1.0, "minimum": 0.0, "title": "Confidence", "type": "number"}}, "required": ["key", "value", "truthful", "confidence"], "title": "ModelField", "type": "object"}}, "properties": {"tax_from": {"items": {"$ref": "#/$defs/ModelField"}, "title": "Tax From", "type": "array"}}, "required": ["tax_from"], "title": "TaxFromFields", "type": "object"}
Document: 1 Wages, tips, other comp. 2307.11 2 Federal income tax withheld 41.98 3 Social security wages 2129.03 4 Social security tax withheld 143.04 5 Medicare wages and tips 2307.11 6 Medicare tax withheld 33.45 d Control number 0000034180 TM6 Dept. Corp. CTQF Employer use only 55394 c Employer's name, address, and ZIP code HUT AMERICAN GROUP LLC 6200 OAK TREE BLVD - SUITE 250 INDEPENDENCE, OH 44131 b Employer's FED ID number 85-4006070 a Employee's SSA number XXX-XX-2489 7 Social security tips 178.08 8 Allocated tips 9 10 Dependent care benefits 11 Nonqualified plans 12a DD | 69.71 14 Other 12b 12c 12d 13 Stat emp. Ret. plan 3rd party sick pay e/f Employee's name, address and ZIP code MARY A DAWKINS 1609 DEATSVILLE HWY MILLBROOK, AL 36054 15 State AL Employer's state ID no. R010981778 16 State wages, tips, etc. 2307.11 17 State income tax 87.86 18 Local wages, tips, etc. 19 Local income tax 20 Locality name City or Local Filing Copy W-2 Wage and Tax 2021 Statement OMB No. 1545-0008 Copy 2 to be filed with employee's City or Local Income Tax Return.
Instruction: extract the following fields (Semicolon separated): Social Security Number;Employer's FED ID number;Employee's SSA number;Wages, tips, other comp.;Federal income tax withheld;Social security wages;Social security tax withheld;Medicare wages and tips;Medicare tax withheld;State;Employer's state ID no.;State wages, tips, etc.;State income tax;Local income tax
Past Error in Tax From Fields: ValueError('Could not find valid json')
Past Error (2) in Tax From Fields: ValueError('Could not find valid json')
Tax From Fields:
All of the forward results that failed look legitimate to me.
I am not sure whether my use case is too complicated, although it does not seem so to me. Any suggestions on how to make this work?
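One thing that may be worth checking offline is whether the raw output really satisfies the JSON Schema shown in the trace, where `value` is declared as a required string. A minimal standard-library sketch that strips the fence and lists entries whose `value` is not a string (the helper name `find_non_string_values` is hypothetical, and `raw` is a shortened copy of the output above):

```python
import json

FENCE = "`" * 3  # three backticks

# Shortened copy of the model output shown earlier in the thread.
raw = (
    FENCE + "json\n"
    '{"tax_from": ['
    '{"key": "State", "value": "AL", "truthful": true, "confidence": 1.0}, '
    '{"key": "Local income tax", "value": null, "truthful": false, "confidence": 0.0}'
    "]}\n" + FENCE
)


def find_non_string_values(output: str) -> list[str]:
    """Return the keys of entries whose value is not a JSON string."""
    body = output.strip()
    if body.startswith(FENCE + "json") and body.endswith(FENCE):
        body = body[7:-3].strip()
    data = json.loads(body)
    return [
        item["key"]
        for item in data["tax_from"]
        if not isinstance(item["value"], str)
    ]


print(find_non_string_values(raw))
```

In the output above, the `Local income tax` entry has `"value": null`, while the second version of `ModelField` declares `value: str` rather than `str | None`. If that mismatch is the cause, declaring `value: str | None` again, as in the first version, may resolve it.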
I am trying to use the TypedPredictor to extract information from some inputs. The program fails because it exceeds the number of retries, but when I inspect the output of the LLM it appears to be parsable.
The LLM used is gemini-1.5-pro
Code
Looking at the LLM input and output, it seems it should be able to handle this.
Output
Error after the retries are exhausted