vitalik / django-ninja

💨 Fast, Async-ready, Openapi, type hints based framework for building APIs
https://django-ninja.dev
MIT License
6.9k stars · 418 forks

[BUG] Nested Enum Classes With Duplicate Names #537

Open OtherBarry opened 2 years ago

OtherBarry commented 2 years ago

Problem

It's a pretty common pattern in Django to define enum/choices classes as nested classes within the model that uses them. If multiple nested classes share the same name, even though they are nested under different classes, the generated OpenAPI schema only includes one of them.

Using __qualname__ instead of __name__ should solve this issue for nested classes.
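The difference is easy to see with plain classes (a minimal, framework-free sketch; the names mirror the example below):

```python
from enum import Enum

class SomeSchema:
    class Status(str, Enum):
        FOO = "foo"

class OtherSchema:
    class Status(str, Enum):
        BAR = "bar"

# __name__ collides across the two nested enums...
print(SomeSchema.Status.__name__, OtherSchema.Status.__name__)
# Status Status

# ...but __qualname__ keeps them distinct
print(SomeSchema.Status.__qualname__, OtherSchema.Status.__qualname__)
# SomeSchema.Status OtherSchema.Status
```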

It might be worth looking into namespacing schemas by prefixing them with their module or django app or something, as I imagine this issue occurs with any duplicated names.

Example

from enum import Enum

from ninja import Schema

class SomeSchema(Schema):
    class Status(str, Enum):
        FOO = "foo"
    status: Status

class OtherSchema(Schema):
    class Status(str, Enum):
        BAR = "bar"
    status: Status

@api.post("/some")
def get_some_thing(request, data: SomeSchema):
    return data.status

@api.post("/other")
def get_other_thing(request, data: OtherSchema):
    return data.status

Only one of the status enums will be present in the resulting schema.


jmduke commented 1 year ago

@OtherBarry did you figure out a workaround for this? I'm running into precisely this problem and can't find a way to even monkeypatch it.

OtherBarry commented 1 year ago

@jmduke I ended up just renaming my enums, so instead of Order.Status and Product.Status I have Order.OrderStatus and Product.ProductStatus, which solved the problem for me.
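In plain-enum form (using Enum as a stand-in for Django's TextChoices; the member values here are illustrative), the rename looks like:

```python
from enum import Enum

class Order:
    class OrderStatus(str, Enum):
        PENDING = "pending"

class Product:
    class ProductStatus(str, Enum):
        ACTIVE = "active"

# The class names themselves are now unique, so even naming
# based on plain __name__ no longer collides.
print(Order.OrderStatus.__name__, Product.ProductStatus.__name__)
# OrderStatus ProductStatus
```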

The __qualname__ fix is a pretty easy solution, but might be a bit of a pain to monkeypatch. I believe the relevant function is here.

jmduke commented 1 year ago

Thanks! Because I'm too stubborn to rename (and the monkeypatching only happens offline, since I'm not serving the schema dynamically), I went with the monkeypatch approach. Patching the function you pointed to only changes the object's title, not the ref ID; monkey-patching get_model_name_map in pydantic.schema ended up doing the trick.

OtherBarry commented 1 year ago

Awesome! Are you able to make a pull request in pydantic with the change? Or just post your monkeypatch here and I'll look into it.

OtherBarry commented 1 year ago

Took a look at the pydantic code and it seems like they actually handle name collisions reasonably well. The issue is that django-ninja calls model_schema() on each model individually, instead of schema() on all models at once, so collisions are never detected.

@vitalik is this something that's a relatively easy fix? I don't know the schema generation code at all so will take me a while to find out.

@jmduke can you post your monkeypatch here so that other people with this issue (namely me) can use it in the meantime?

jmduke commented 1 year ago

@OtherBarry you beat me to the explanation :) There are a couple of TODOs here that I think hint at the issue, but I couldn't quite wrap my head around the indirection in this module. I agree that a better approach for django-ninja might be to delegate as much of the mapping + conflict resolution as possible to pydantic, which appears to handle it quite well.

To answer your question, though, the monkey-patch in question:

from pydantic import schema

def monkey_patched_get_model_name_map(
    unique_models: schema.TypeModelSet,
) -> dict[schema.TypeModelOrEnum, str]:
    """
    Process a set of models and generate unique names for them to be used as keys in the JSON Schema
    definitions. By default the names are the same as the class name. But if two models in different Python
    modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be
    based on the Python module path for those conflicting models to prevent name collisions.
    :param unique_models: a Python set of models
    :return: dict mapping models to names
    """
    name_model_map = {}
    conflicting_names: set[str] = set()
    for model in unique_models:
        model_name = schema.normalize_name(model.__qualname__.replace(".", ""))
        if model_name in conflicting_names:
            model_name = schema.get_long_model_name(model)
            name_model_map[model_name] = model
        elif model_name in name_model_map:
            conflicting_names.add(model_name)
            conflicting_model = name_model_map.pop(model_name)
            name_model_map[
                schema.get_long_model_name(conflicting_model)
            ] = conflicting_model
            name_model_map[schema.get_long_model_name(model)] = model
        else:
            name_model_map[model_name] = model
    return {v: k for k, v in name_model_map.items()}

schema.get_model_name_map = monkey_patched_get_model_name_map

The only line here that changes from the original is:

model_name = schema.normalize_name(model.__qualname__.replace(".", ""))
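Concretely, for a nested class the patched line builds the name from the dotted qualname with the dots stripped (class name illustrative):

```python
# e.g. SomeSchema.Status.__qualname__ == "SomeSchema.Status"
qualname = "SomeSchema.Status"

# what the patched line feeds into normalize_name
print(qualname.replace(".", ""))
# SomeSchemaStatus
```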

A couple notes:

jmduke commented 9 months ago

@vitalik You mentioned in https://github.com/vitalik/django-ninja/issues/862 that this is no longer possible in pydantic2, and indeed after trying to migrate my setup to django-ninja@1.0rc I run into the issue as outlined in https://github.com/vitalik/django-ninja/issues/537. Is there a recommended path forward? This is a blocker for me, and I imagine it's not a particularly uncommon use case.

furious-luke commented 3 months ago

Edit: This patch is only partially correct, please see my next comment in addition to this one.

Hey @jmduke! I'm not sure if you're still blocked by this, but I came across the same issue in my work and put together a quick monkeypatch to temporarily resolve the issue. I'll add the monkey patch below.

In my case, I only care about resolving clashing names for nested TextChoices on my models. To that end, I've only patched the function related to generating titles for enumerations. Similar to the original patch, there is only one line changed from the original Pydantic function (it's the line containing the __qualname__ access).

Anyhow, I hope you find it useful!

from enum import Enum
import inspect
from operator import attrgetter
from typing import Any, Literal

from pydantic import ConfigDict
from pydantic_core import core_schema, CoreSchema
from pydantic.json_schema import JsonSchemaValue
import pydantic._internal._std_types_schema
from pydantic._internal._core_utils import get_type_ref
from pydantic._internal._schema_generation_shared import GetJsonSchemaHandler

def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
    cases: list[Any] = list(enum_type.__members__.values())

    enum_ref = get_type_ref(enum_type)
    description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
    if description == 'An enumeration.':  # This is the default value provided by enum.EnumMeta.__new__; don't use it
        description = None
    js_updates = {'title': enum_type.__qualname__.replace(".", ""), 'description': description}
    js_updates = {k: v for k, v in js_updates.items() if v is not None}

    sub_type: Literal['str', 'int', 'float'] | None = None
    if issubclass(enum_type, int):
        sub_type = 'int'
        value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
    elif issubclass(enum_type, str):
        # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
        sub_type = 'str'
        value_ser_type = core_schema.simple_ser_schema('str')
    elif issubclass(enum_type, float):
        sub_type = 'float'
        value_ser_type = core_schema.simple_ser_schema('float')
    else:
        # TODO this is an ugly hack, how do we trigger an Any schema for serialization?
        value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

    if cases:

        def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(schema)
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # we don't want to add the missing to the schema if it's the default one
        default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__  # type: ignore
        enum_schema = core_schema.enum_schema(
            enum_type,
            cases,
            sub_type=sub_type,
            missing=None if default_missing else enum_type._missing_,
            ref=enum_ref,
            metadata={'pydantic_js_functions': [get_json_schema]},
        )

        if config.get('use_enum_values', False):
            enum_schema = core_schema.no_info_after_validator_function(
                attrgetter('value'), enum_schema, serialization=value_ser_type
            )

        return enum_schema

    else:

        def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # Use an isinstance check for enums with no cases.
        # The most important use case for this is creating TypeVar bounds for generics that should
        # be restricted to enums. This is more consistent than it might seem at first, since you can only
        # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
        # We use the get_json_schema function when an Enum subclass has been declared with no cases
        # so that we can still generate a valid json schema.
        return core_schema.is_instance_schema(
            enum_type,
            metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
        )

pydantic._internal._std_types_schema.get_enum_core_schema = get_enum_core_schema

furious-luke commented 3 months ago

After some more testing I found the above patch fails to correct the JSON schema definition refs. The code in Pydantic that generates, caches, and uses these references is pretty complicated, so I'm sure there's a better way, but I've made another monkey patch to resolve the ref issue, too.

The solution I've used is to rearrange the preferential order of the reference identifiers generated by Pydantic so that the most specific option wins. It results in ugly refs, ~but those never actually get presented to the user anywhere I think, so shouldn't be too impactful~ no, they do actually appear in the JSON schema when downloaded. In my particular case it won't cause any issues, it's just ugly.

Anyway, here's the full patch, including the above one, and the additional function to correct the refs:

import re
from enum import Enum
import inspect
from operator import attrgetter
from typing import Any, Literal

from pydantic import ConfigDict
from pydantic_core import core_schema, CoreSchema
from pydantic.json_schema import JsonSchemaValue, CoreModeRef, DefsRef, _MODE_TITLE_MAPPING
import pydantic._internal._std_types_schema
from pydantic._internal._core_utils import get_type_ref
from pydantic._internal._schema_generation_shared import GetJsonSchemaHandler

def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
    cases: list[Any] = list(enum_type.__members__.values())

    enum_ref = get_type_ref(enum_type)
    description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
    if description == 'An enumeration.':  # This is the default value provided by enum.EnumMeta.__new__; don't use it
        description = None
    js_updates = {'title': enum_type.__qualname__.replace(".", ""), 'description': description}
    js_updates = {k: v for k, v in js_updates.items() if v is not None}

    sub_type: Literal['str', 'int', 'float'] | None = None
    if issubclass(enum_type, int):
        sub_type = 'int'
        value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
    elif issubclass(enum_type, str):
        # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
        sub_type = 'str'
        value_ser_type = core_schema.simple_ser_schema('str')
    elif issubclass(enum_type, float):
        sub_type = 'float'
        value_ser_type = core_schema.simple_ser_schema('float')
    else:
        # TODO this is an ugly hack, how do we trigger an Any schema for serialization?
        value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

    if cases:

        def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(schema)
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # we don't want to add the missing to the schema if it's the default one
        default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__  # type: ignore
        enum_schema = core_schema.enum_schema(
            enum_type,
            cases,
            sub_type=sub_type,
            missing=None if default_missing else enum_type._missing_,
            ref=enum_ref,
            metadata={'pydantic_js_functions': [get_json_schema]},
        )

        if config.get('use_enum_values', False):
            enum_schema = core_schema.no_info_after_validator_function(
                attrgetter('value'), enum_schema, serialization=value_ser_type
            )

        return enum_schema

    else:

        def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # Use an isinstance check for enums with no cases.
        # The most important use case for this is creating TypeVar bounds for generics that should
        # be restricted to enums. This is more consistent than it might seem at first, since you can only
        # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
        # We use the get_json_schema function when an Enum subclass has been declared with no cases
        # so that we can still generate a valid json schema.
        return core_schema.is_instance_schema(
            enum_type,
            metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
        )

pydantic._internal._std_types_schema.get_enum_core_schema = get_enum_core_schema

def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef:
    """Override this method to change the way that definitions keys are generated from a core reference.

    Args:
        core_mode_ref: The core reference.

    Returns:
        The definitions key.
    """
    # Split the core ref into "components"; generic origins and arguments are each separate components
    core_ref, mode = core_mode_ref
    components = re.split(r'([\][,])', core_ref)
    # Remove IDs from each component
    components = [x.rsplit(':', 1)[0] for x in components]
    core_ref_no_id = ''.join(components)
    # Remove everything before the last period from each "component"
    components = [re.sub(r'(?:[^.[\]]+\.)+((?:[^.[\]]+))', r'\1', x) for x in components]
    short_ref = ''.join(components)

    mode_title = _MODE_TITLE_MAPPING[mode]

    # It is important that the generated defs_ref values be such that at least one choice will not
    # be generated for any other core_ref. Currently, this should be the case because we include
    # the id of the source type in the core_ref
    name = DefsRef(self.normalize_name(short_ref))
    name_mode = DefsRef(self.normalize_name(short_ref) + f'-{mode_title}')
    module_qualname = DefsRef(self.normalize_name(core_ref_no_id))
    module_qualname_mode = DefsRef(f'{module_qualname}-{mode_title}')
    module_qualname_id = DefsRef(self.normalize_name(core_ref))
    occurrence_index = self._collision_index.get(module_qualname_id)
    if occurrence_index is None:
        self._collision_counter[module_qualname] += 1
        occurrence_index = self._collision_index[module_qualname_id] = self._collision_counter[module_qualname]

    module_qualname_occurrence = DefsRef(f'{module_qualname}__{occurrence_index}')
    module_qualname_occurrence_mode = DefsRef(f'{module_qualname_mode}__{occurrence_index}')

    self._prioritized_defsref_choices[module_qualname_occurrence_mode] = [
        module_qualname_occurrence_mode,
        name,
        name_mode,
        module_qualname,
        module_qualname_mode,
        module_qualname_occurrence,
    ]

    return module_qualname_occurrence_mode

pydantic.json_schema.GenerateJsonSchema.get_defs_ref = get_defs_ref
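For anyone puzzling over the regex gymnastics at the top of get_defs_ref, here is a standalone trace of the component splitting on a made-up core ref (the ref value is illustrative; real pydantic core refs end in a numeric type id like this):

```python
import re

core_ref = "myapp.models.Order.Status:140234"

# Split into "components"; generic origins/arguments would be separated
# by '[', ']' and ',' (none in this simple ref)
components = re.split(r'([\][,])', core_ref)

# Remove the trailing ':<id>' from each component
components = [x.rsplit(':', 1)[0] for x in components]
core_ref_no_id = ''.join(components)
print(core_ref_no_id)
# myapp.models.Order.Status

# Keep only the last dotted segment of each component
components = [re.sub(r'(?:[^.[\]]+\.)+((?:[^.[\]]+))', r'\1', x) for x in components]
short_ref = ''.join(components)
print(short_ref)
# Status
```

The patch then prefers the fully qualified `module_qualname`-based names over `short_ref`, which is what keeps two nested `Status` enums from sharing a ref.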