pydantic / pydantic-settings

Settings management using pydantic
https://docs.pydantic.dev/latest/usage/pydantic_settings/
MIT License

Discriminated unions with callable discriminator don't work correctly. #284

Closed jenskeiner closed 5 months ago

jenskeiner commented 5 months ago

Dear maintainers,

I believe there may be a bug in how pydantic_settings interprets model fields that are discriminated unions with a callable discriminator.

Consider the following self-contained example:

import os
from typing import Literal, Union, Any, Annotated

from pydantic import BaseModel, Tag, Discriminator
from pydantic_settings import BaseSettings, SettingsConfigDict

class A(BaseModel):
    x: Literal['a'] = 'a'
    y: str

class B(BaseModel):
    x: Literal['b'] = 'b'
    z: str

def get_discriminator_value(v: Any) -> Union[Literal['a'], Literal['b'], None]:
    if isinstance(v, dict):
        v0 = v.get("x")
    else:
        v0 = getattr(v, "x", None)

    if v0 == 'a':
        return "a"
    elif v0 == "b":
        return "b"
    else:
        return None

class S(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="SETTINGS_", env_nested_delimiter="__")
    a_or_b: Annotated[Union[Annotated[A, Tag('a')], Annotated[B, Tag('b')]], Discriminator(get_discriminator_value)]

# from pydantic import Json
# import pydantic_settings.sources
# from typing_extensions import _AnnotatedAlias, get_args, get_origin
#
# def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
#     if any(isinstance(md, Json) for md in metadata):  # type: ignore[misc]
#         return False
#     if isinstance(annotation, _AnnotatedAlias):
#         # Annotated[...] may carry several metadata items, so collect them all.
#         arg, *meta = get_args(annotation)
#         return _annotation_is_complex(arg, meta)
#     origin = get_origin(annotation)
#     return (
#         pydantic_settings.sources._annotation_is_complex_inner(annotation)
#         or pydantic_settings.sources._annotation_is_complex_inner(origin)
#         or hasattr(origin, "__pydantic_core_schema__")
#         or hasattr(origin, "__get_pydantic_core_schema__")
#     )
#
#
# # Overwrite the internal method in pydantic_settings.
# # TODO: Remove hack once the original issue has been fixed in the upstream package.
# pydantic_settings.sources._annotation_is_complex = _annotation_is_complex

if __name__ == "__main__":
    os.environ['SETTINGS_A_OR_B__X'] = "a"
    os.environ['SETTINGS_A_OR_B__Y'] = "foo"
    s = S()
    assert s.a_or_b.y == "foo"

The settings model S declares a field a_or_b as a discriminated union. There can be good reasons to use a callable discriminator, as in this case. Before the settings instance is created, the environment is prepared so that the variable SETTINGS_A_OR_B__X has the value a, which should cause a_or_b to be populated with an instance of the model class A. Likewise, the variable SETTINGS_A_OR_B__Y should set the field y on that instance of A to "foo".
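To make the intended mapping concrete, here is a minimal stdlib-only sketch of how prefixed, delimiter-separated variable names can be turned into a nested dict. The helper explode_env is hypothetical and is not the library's implementation; lower-casing the names is an assumption that stands in for case-insensitive matching.

```python
from typing import Any, Mapping


def explode_env(env: Mapping[str, str], prefix: str, delimiter: str) -> dict[str, Any]:
    """Hypothetical helper: build a nested dict from prefixed variable names."""
    result: dict[str, Any] = {}
    for name, value in env.items():
        if not name.startswith(prefix):
            continue  # unrelated variable, ignore it
        # Assumption: names are matched case-insensitively, so lower-case them.
        *parents, leaf = name[len(prefix):].lower().split(delimiter)
        node = result
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return result


nested = explode_env(
    {"SETTINGS_A_OR_B__X": "a", "SETTINGS_A_OR_B__Y": "foo", "HOME": "/root"},
    prefix="SETTINGS_",
    delimiter="__",
)
print(nested)
```

A dict of this shape is what the union field would need to receive so that the discriminator can read x and select A.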

When I execute the code using pydantic 2.7.1 and pydantic_settings 2.2.1, I get an error because the field a_or_b is not treated as a complex field (e.g. a sub-model), although it is one. Consequently, the environment is searched for a variable named SETTINGS_A_OR_B, which is not set.

If the commented-out section of the code is activated, the relevant internal function is patched to recurse into Annotated[...] types. In that case, the code runs as expected.
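The effect of the recursive call can be illustrated without pydantic. The following is a simplified sketch: Model, ModelA, ModelB, naive_is_complex and is_complex are hypothetical stand-ins, not names from pydantic_settings. It shows that an Annotated[...] type reports Annotated as its origin, so a check that does not unwrap it never sees the model class inside.

```python
from typing import Annotated, Any, Union, get_args, get_origin


class Model:
    """Hypothetical stand-in for a pydantic BaseModel."""


class ModelA(Model):
    pass


class ModelB(Model):
    pass


def naive_is_complex(annotation: Any) -> bool:
    """Check without unwrapping: Annotated[...] aliases are not classes."""
    return isinstance(annotation, type) and issubclass(annotation, Model)


def is_complex(annotation: Any) -> bool:
    """Same check, but recursing into Annotated[...] and Union[...]."""
    origin = get_origin(annotation)
    if origin is Annotated:
        # get_args returns (inner_type, *metadata); recurse on the inner type.
        inner, *_metadata = get_args(annotation)
        return is_complex(inner)
    if origin is Union:
        return any(is_complex(arg) for arg in get_args(annotation))
    return naive_is_complex(annotation)


tagged_union = Union[Annotated[ModelA, "a"], Annotated[ModelB, "b"]]
print(naive_is_complex(Annotated[ModelA, "a"]), is_complex(tagged_union))
```

Under these assumptions, the tagged union members are only recognized as sub-models once the Annotated wrapper is peeled off, which matches the behavior described above.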

Can you confirm this is a bug? Is the modified internal function _annotation_is_complex from above a valid fix?

Happy to contribute this in a PR, if confirmed.

Best,

Jens

hramezani commented 5 months ago

Thanks @jenskeiner for reporting the issue and providing the fix.

The fix looks good. It would be great if you could create a PR.

jenskeiner commented 5 months ago

Hi @hramezani, PR #285 is ready for review.