FullFact / health-misinfo-shared

Raphael health misinformation project, shared by Full Fact and Google
MIT License

Repeat claim extraction when LLM output is wrong #140

Closed: c-j-johnston closed this 1 week ago

c-j-johnston commented 2 weeks ago

Overview

Sometimes Gemini's output is malformed: a claim JSON is missing a key (usually original_text), which can cause the app to crash. This is rare, so allowing up to two retries in llm.py should stop it from being a problem. On the third failure, only the failing claims should be returned.

Requirements

  1. An assertion function is created to ensure LLM output JSONs have the correct keys.
  2. If the assertion fails for any output claim JSON, the LLM runs over the whole chunk again.
  3. If the assertion fails a third time, we return null for the offending value.
  4. We test this.
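A minimal sketch of requirements 2 and 3, under some assumptions: `run_llm` is a hypothetical stand-in for whatever in llm.py sends a chunk to Gemini and parses the response into a list of claim dicts, `is_valid_claim` is a simplified stand-in for the assertion function from requirement 1, and "return null for the offending value" is read here as filling each missing key with None on the final attempt.

```python
from typing import Callable

# Top-level keys every claim JSON should have (taken from this issue).
REQUIRED_KEYS = ("claim", "original_text", "labels")
MAX_ATTEMPTS = 3  # one initial run plus up to two retries


def is_valid_claim(claim: object) -> bool:
    """Hypothetical, simplified stand-in for the assertion function."""
    return isinstance(claim, dict) and all(k in claim for k in REQUIRED_KEYS)


def fill_missing_keys(claim: object) -> dict:
    """Set any missing required key to None, keeping the keys that are present."""
    defaults = {k: None for k in REQUIRED_KEYS}
    if not isinstance(claim, dict):
        return defaults
    return {**defaults, **claim}


def extract_claims_with_retry(
    chunk: str,
    run_llm: Callable[[str], list],
    max_attempts: int = MAX_ATTEMPTS,
) -> list:
    """Re-run the LLM over the whole chunk until every claim validates.

    `run_llm` is hypothetical: it takes the chunk text and returns the
    parsed list of claim dicts.
    """
    claims: list = []
    for _ in range(max_attempts):
        claims = run_llm(chunk)
        if all(is_valid_claim(c) for c in claims):
            return claims
    # Still failing after the last attempt: null out the missing keys so
    # downstream code gets None instead of raising a KeyError.
    return [c if is_valid_claim(c) else fill_missing_keys(c) for c in claims]
```

Because the whole chunk is re-run (requirement 2), a fake `run_llm` that returns a malformed claim twice and a well-formed one on the third call would be called exactly three times, and its third result returned unchanged.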

Notes and additional information

dearden commented 2 weeks ago

On point 1: I wrote some assertions for checking the format over in evaluation.

Although, if we had a Claim dataclass with a from_dict function, malformed output would be a lot easier to catch than it is with a plain dict.
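A sketch of the dataclass approach suggested above. `Claim` and `from_dict` follow the names in the comment; the nested `Labels` class and the choice to raise a KeyError listing the missing keys are assumptions made for illustration, not existing code in the repo.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class Labels:
    """Hypothetical nested dataclass for the labels JSON."""
    understandability: str
    type_of_claim: str
    type_of_medical_claim: str
    support: str
    harm: str

    @classmethod
    def from_dict(cls, data: Any) -> "Labels":
        names = [f.name for f in fields(cls)]
        if not isinstance(data, dict):
            raise KeyError(f"labels is not an object: {data!r}")
        missing = [n for n in names if n not in data]
        if missing:
            raise KeyError(f"labels missing keys: {missing}")
        # Copy only the known fields, so extra keys from the LLM are ignored.
        return cls(**{n: data[n] for n in names})


@dataclass
class Claim:
    claim: str
    original_text: str
    labels: Labels

    @classmethod
    def from_dict(cls, data: Any) -> "Claim":
        if not isinstance(data, dict):
            raise KeyError(f"claim is not an object: {data!r}")
        missing = [k for k in ("claim", "original_text", "labels") if k not in data]
        if missing:
            raise KeyError(f"claim missing keys: {missing}")
        return cls(
            claim=data["claim"],
            original_text=data["original_text"],
            labels=Labels.from_dict(data["labels"]),
        )
```

Copying only the known fields, rather than calling `cls(**data)` directly, means unexpected extra keys are ignored instead of raising a TypeError, so a missing key is the only thing a caller needs to catch.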

c-j-johnston commented 2 weeks ago

This is probably bad practice, but here's a WIP assertion function. I'm keeping it here because it didn't end up in a branch and was going to get lost in a merge. Sorry.

def assert_output_json_format(output: dict) -> bool:
    """Return True if a claim JSON has every expected key; otherwise print it and return False."""
    required_keys = ["claim", "original_text", "labels"]
    required_label_keys = [
        "understandability",
        "type_of_claim",
        "type_of_medical_claim",
        "support",
        "harm",
    ]
    passed = all(k in output for k in required_keys)
    if passed:
        # labels is itself a dict, so check it is one and has every label key.
        labels = output["labels"]
        passed = isinstance(labels, dict) and all(
            k in labels for k in required_label_keys
        )
    if not passed:
        print(output)
    return passed

andylolz commented 2 weeks ago

Per @dearden’s suggestion, here is a similar sort of thing, but using pydantic (which we appear to have as a dependency anyway):

from pydantic import BaseModel, ValidationError

class LabelsModel(BaseModel):
    understandability: str
    type_of_claim: str
    type_of_medical_claim: str
    support: str
    harm: str

class ClaimModel(BaseModel):
    claim: str
    original_text: str
    labels: LabelsModel

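def assert_output_json_format(output: dict) -> bool:
    # Validating against ClaimModel checks the top-level keys and the nested labels keys.
    try:
        ClaimModel(**output)
    except (ValidationError, TypeError):  # TypeError: output was not a mapping at all
        print(output)
        return False
    return True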