outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

JSON schema constraints not respected by outlines #654

Open streitl opened 6 months ago

streitl commented 6 months ago

Describe the issue as clearly as possible:

Hi! I'm trying to use outlines to generate Pydantic class instances where all of the fields are optional but at least one has to be set. I enforce this constraint both in a Pydantic model validator and in the JSON schema, but outlines does not respect the JSON schema constraints, so the generated object fails Pydantic validation. When I test the JSON schema with online validators, they correctly reject empty objects like {}, which is precisely what outlines generates. Are some of my JSON schema constraints not supported yet, or is this a bug? Thank you!

Steps/code to reproduce the bug:

import outlines
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator

class MyModel(BaseModel):
    a: int | None = Field(default=None)
    b: str | None = Field(default=None)

    model_config = ConfigDict(
        json_schema_extra={
            "anyOf": [
                {"properties": {field: {"not": {"type": "null"}}}, "required": [field]}
                for field in ["a", "b"]
            ]
        }
    )

    @model_validator(mode="after")
    def validate_my_model(self) -> "MyModel":
        if self.a is None and self.b is None:
            raise ValueError("Cannot have both a and b be None")
        return self

if __name__ == "__main__":
    device: str = "cpu"
    rng = torch.Generator(device=device)
    rng.manual_seed(90302)

    outlines_model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2", device=device)
    generator = outlines.generate.json(outlines_model, MyModel, whitespace_pattern="")
    # Pass the seeded generator so that the run is reproducible
    print(generator("Make an empty json object like this one: {}", rng=rng))

Expected result:

An instance of MyModel where at least one of the fields is non-null.
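For reference, the anyOf added through json_schema_extra states the same rule as the model validator: for at least one field, the key must be present and its value must not be null. A minimal standard-library sketch of that check (has_non_null_field is a hypothetical helper, not part of Outlines or Pydantic) shows why the {} that outlines produced should have been excluded:

```python
import json
from typing import Any, Iterable


def has_non_null_field(obj: dict[str, Any], fields: Iterable[str]) -> bool:
    # One anyOf branch per field: the key is "required" and its value matches
    # "not": {"type": "null"} (JSON null becomes Python None after json.loads).
    return any(field in obj and obj[field] is not None for field in fields)


FIELDS = ["a", "b"]
print(has_non_null_field(json.loads("{}"), FIELDS))                      # False
print(has_non_null_field(json.loads('{"a": 1}'), FIELDS))                # True
print(has_non_null_field(json.loads('{"a": null, "b": null}'), FIELDS))  # False
```

An object passes only if some branch of the anyOf is satisfied, which is exactly what the validator's "Cannot have both a and b be None" check enforces.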

Error message:

pydantic_core._pydantic_core.ValidationError: 1 validation error for MyModel
  Value error, Cannot have both a and b be None [type=value_error, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.6/v/value_error

Outlines/Python version information:

Version information

```
outlines==0.0.29
Python 3.11.7 (main, Dec 8 2023, 18:56:58) [GCC 11.4.0]

accelerate==0.26.1
aiofiles==23.2.1
aiohttp==3.9.3
aiosignal==1.3.1
altair==5.2.0
annotated-types==0.6.0
anyio==4.2.0
asttokens==2.4.1
attrs==23.2.0
auto-gptq==0.6.0
bitsandbytes==0.42.0
black==23.12.1
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
coloredlogs==15.0.1
contourpy==1.2.0
cycler==0.12.1
datasets==2.17.0
decorator==5.1.1
dill==0.3.8
diskcache==5.6.3
docstring-parser==0.15
einops==0.7.0
executing==2.0.1
fastapi==0.109.2
ffmpy==0.3.2
filelock==3.13.1
flake8==6.1.0
fonttools==4.48.1
frozenlist==1.4.1
fsspec==2023.10.0
gekko==1.0.6
gradio==3.50.2
gradio_client==0.6.1
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.3
humanfriendly==10.0
idna==3.6
importlib-resources==6.1.1
interegular==0.3.3
ipdb==0.13.13
ipython==8.21.0
isort==5.13.2
jedi==0.19.1
jieba==0.42.1
Jinja2==3.1.3
joblib==1.3.2
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
lark==1.1.9
llama_cpp_python==0.2.41
llmtuner==0.5.1
llvmlite==0.42.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.2
matplotlib-inline==0.1.6
mccabe==0.7.0
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy==1.8.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.2.1
nltk==3.8.1
numba==0.59.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
optimum==1.16.2
orjson==3.9.13
outlines==0.0.29
packaging==23.2
pandas==2.2.0
pandas-stubs==2.1.4.231227
parso==0.8.3
pathspec==0.12.1
peft==0.8.2
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.2.0
prompt-toolkit==3.0.43
protobuf==4.25.2
psutil==5.9.8
psycopg==3.1.18
psycopg-binary==3.1.18
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==15.0.0
pyarrow-hotfix==0.6
pycodestyle==2.11.1
pydantic==2.6.1
pydantic-2-mermaid==0.6.0
pydantic_core==2.16.2
pydub==0.25.1
pyflakes==3.1.0
Pygments==2.17.2
pyparsing==3.1.1
python-dateutil==2.8.2
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
referencing==0.33.0
regex==2023.12.25
requests==2.31.0
rich==13.7.0
rouge==1.0.1
rouge-chinese==1.0.3
rpds-py==0.17.1
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
semantic-version==2.10.0
sentence-transformers==2.3.1
sentencepiece==0.1.99
shtab==1.6.5
six==1.16.0
sniffio==1.3.0
sse-starlette==2.0.0
stack-data==0.6.3
starlette==0.36.3
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.15.2
toolz==0.12.1
torch==2.2.0
tqdm==4.66.2
traitlets==5.14.1
transformers==4.37.2
triton==2.2.0
trl==0.7.10
types-pytz==2024.1.0.20240203
types-tqdm==4.66.0.20240106
typing_extensions==4.9.0
tyro==0.7.2
tzdata==2024.1
urllib3==2.2.0
uvicorn==0.27.1
wcwidth==0.2.13
websockets==11.0.3
xxhash==3.4.1
yarl==1.9.4
```

Context for the issue:

No response

lapp0 commented 6 months ago

This is a bug in Outlines: it can't correctly resolve a schema in which anyOf and properties appear side by side. If properties is present, all other directives are ignored.

>>> import json
>>> print(json.dumps(MyModel.model_json_schema(), indent=2))
{
  "anyOf": [
    {
      "properties": {
        "a": {
          "not": {
            "type": "null"
          }
        }
      },
      "required": [
        "a"
      ]
    },
    {
      "properties": {
        "b": {
          "not": {
            "type": "null"
          }
        }
      },
      "required": [
        "b"
      ]
    }
  ],
  "properties": {
    "a": {
      "anyOf": [
        {
          "type": "integer"
        },
        {
          "type": "null"
        }
      ],
      "default": null,
      "title": "A"
    },
    "b": {
      "anyOf": [
        {
          "type": "string"
        },
        {
          "type": "null"
        }
      ],
      "default": null,
      "title": "B"
    }
  },
  "title": "MyModel",
  "type": "object"
}

For now, I suggest using this schema instead:

{
  "type": "object",
  "title": "MyModel",
  "anyOf": [
    {
      "type": "object",
      "required": ["a"],
      "properties": {
        "a": {"anyOf": [{"type": "integer"}]},
        "b": {"anyOf": [{"type": "string"}, {"type": "null"}]}
      }
    },
    {
      "type": "object",
      "required": ["b"],
      "properties": {
        "b": {"anyOf": [{"type": "string"}]},
        "a": {"anyOf": [{"type": "integer"}, {"type": "null"}]}
      }
    }
  ]
}
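To make the failure concrete, here is a rough sketch with hand-written regexes. They are assumptions for illustration, not the regexes Outlines actually builds: no whitespace (as with whitespace_pattern=""), simplified strings without escape sequences, and keys fixed to the order listed in each branch. A pattern derived from the properties block alone lets every key be omitted, so {} matches; the workaround's anyOf branches each force one key to be present with a non-null value, so {} no longer matches.

```python
import re

# Simplified value patterns (hypothetical; strings have no escape sequences).
INT = r'-?(0|[1-9][0-9]*)'
STR = r'"[^"\\]*"'
NULLABLE_INT = r'(' + INT + r'|null)'
NULLABLE_STR = r'(' + STR + r'|null)'

# "properties" alone: both keys optional and nullable, so the empty object matches.
PROPERTIES_ONLY = (
    r'\{("a":' + NULLABLE_INT + r'(,"b":' + NULLABLE_STR + r')?'
    + r'|"b":' + NULLABLE_STR + r')?\}'
)

# The suggested workaround: one branch per key that is required and non-null.
WORKAROUND = (
    r'(?:\{"a":' + INT + r'(,"b":' + NULLABLE_STR + r')?\}'
    + r'|\{"b":' + STR + r'(,"a":' + NULLABLE_INT + r')?\})'
)

for text in ['{}', '{"a":1,"b":null}', '{"a":null}']:
    print(text, bool(re.fullmatch(PROPERTIES_ONLY, text)), bool(re.fullmatch(WORKAROUND, text)))
```

Only the second object is accepted by both patterns; {} and {"a":null} slip through the properties-only pattern but are rejected by the workaround.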

For anyone who's looking to implement this: it involves intersecting (&-ing) two patterns, the regex built from properties and the one built from anyOf. I recommend using greenery:

import greenery

# properties_pattern / anyof_pattern: the regexes built from the two directives
combined_pattern = greenery.parse(properties_pattern) & greenery.parse(anyof_pattern)
str(combined_pattern)
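Conceptually, & on two regular expressions is language intersection: a string is accepted only if both patterns accept it. The classic way to compute it is the product construction on deterministic finite automata, sketched below. The DFA class, the intersect function and the two toy automata are hypothetical stand-ins, not greenery's API: "a" and "b" play the role of the two fields, one automaton accepts any sequence (like properties alone, which admits the empty object) and the other requires at least one symbol (like the anyOf).

```python
from collections.abc import Hashable
from dataclasses import dataclass


@dataclass
class DFA:
    alphabet: frozenset[str]
    # (state, symbol) -> next state; a missing entry means the input is rejected
    transitions: dict[tuple[Hashable, str], Hashable]
    start: Hashable
    accepting: frozenset[Hashable]

    def accepts(self, text: str) -> bool:
        state = self.start
        for symbol in text:
            if (state, symbol) not in self.transitions:
                return False
            state = self.transitions[(state, symbol)]
        return state in self.accepting


def intersect(d1: DFA, d2: DFA) -> DFA:
    """Product construction: run both DFAs in lockstep and accept only
    when both end in an accepting state."""
    alphabet = d1.alphabet & d2.alphabet
    start = (d1.start, d2.start)
    transitions: dict[tuple[Hashable, str], Hashable] = {}
    accepting: set[Hashable] = set()
    seen = {start}
    frontier = [start]
    while frontier:  # explore only the reachable state pairs
        pair = frontier.pop()
        s1, s2 = pair
        if s1 in d1.accepting and s2 in d2.accepting:
            accepting.add(pair)
        for symbol in alphabet:
            if (s1, symbol) in d1.transitions and (s2, symbol) in d2.transitions:
                nxt = (d1.transitions[(s1, symbol)], d2.transitions[(s2, symbol)])
                transitions[(pair, symbol)] = nxt
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append(nxt)
    return DFA(alphabet, transitions, start, frozenset(accepting))


# Accepts any sequence of "a"/"b", including the empty one.
all_optional = DFA(frozenset("ab"), {(0, "a"): 0, (0, "b"): 0}, 0, frozenset({0}))
# Accepts only sequences with at least one symbol.
at_least_one = DFA(
    frozenset("ab"),
    {(0, "a"): 1, (0, "b"): 1, (1, "a"): 1, (1, "b"): 1},
    0,
    frozenset({1}),
)
both = intersect(all_optional, at_least_one)
print(all_optional.accepts(""), both.accepts(""), both.accepts("ab"))  # True False True
```

The empty input is accepted by the "all optional" automaton but rejected by the intersection, mirroring how combining the two patterns would rule out {}.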

streitl commented 6 months ago

Your answer helped me work around the bug with a quick hack: a callable passed as Pydantic's json_schema_extra that generates a schema similar to your output. (I need the JSON schema to be generated dynamically, since I change my Pydantic models quite often.)

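from pydantic import BaseModel, ConfigDict
from pydantic.config import JsonDict

def force_at_least_one_field_to_exist(schema: JsonDict) -> None:
    # Replace the top-level "properties" with an anyOf holding one branch per
    # (field, non-null type): that field is required and non-null, while the
    # other fields keep their original nullable definitions.
    # Assumes every field is Optional, i.e. its schema is an "anyOf" with a null branch.
    properties: JsonDict = schema.pop("properties")  # type: ignore
    schema |= {
        "anyOf": [
            {
                "properties": {
                    field_name: {**_type, **rest},  # type: ignore
                    **other_fields,
                },
                "required": [field_name],
            }
            for field_name in properties.keys()
            # single-element loops bind helper variables inside the comprehension
            for other_fields in [
                {k: v for k, v in properties.items() if k != field_name}
            ]
            # every non-null alternative of this field's anyOf
            for _type in [
                el for el in properties[field_name]["anyOf"] if el.get("type") != "null"  # type: ignore
            ]
            # the field's remaining keywords (e.g. "default", "title")
            for rest in [
                {k: v for k, v in properties[field_name].items() if k != "anyOf"}  # type: ignore
            ]
        ]
    }

class MyModel(BaseModel):
    ...
    model_config = ConfigDict(json_schema_extra=force_at_least_one_field_to_exist)
    ...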
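The nested comprehension above binds helper variables with single-element for loops, which can be hard to follow. Below is an equivalent loop-based sketch (same resulting schema, under the same assumption that every property is an anyOf with a null branch), applied to a plain dict copied from the schema printed earlier in the thread. JsonDict is aliased to a plain dict here instead of being imported from pydantic.config, so the example runs without Pydantic.

```python
from typing import Any

JsonDict = dict[str, Any]  # stand-in for pydantic.config.JsonDict


def force_at_least_one_field_to_exist(schema: JsonDict) -> None:
    properties: JsonDict = schema.pop("properties")
    branches = []
    for field_name, field_schema in properties.items():
        # Keywords other than anyOf (e.g. "default", "title") are carried over.
        rest = {k: v for k, v in field_schema.items() if k != "anyOf"}
        # The other fields keep their original, nullable definitions.
        other_fields = {k: v for k, v in properties.items() if k != field_name}
        for non_null_type in field_schema["anyOf"]:
            if non_null_type.get("type") == "null":
                continue
            branches.append({
                "properties": {field_name: {**non_null_type, **rest}, **other_fields},
                "required": [field_name],
            })
    schema["anyOf"] = branches


schema = {
    "properties": {
        "a": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "A"},
        "b": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": None, "title": "B"},
    },
    "title": "MyModel",
    "type": "object",
}
force_at_least_one_field_to_exist(schema)
for branch in schema["anyOf"]:
    print(branch["required"], branch["properties"])
```

Each branch makes one field required with its non-null type while the other field stays nullable, a shape close to the schema lapp0 suggested (minus the per-branch "type": "object").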

I hope this will help somebody 😄 Thank you for your help!