sbdchd / django-types

:doughnut: Type stubs for Django
MIT License

Allow to limit charfield to the defined choices #44

Closed: bellini666 closed this 3 years ago

bellini666 commented 3 years ago

Fix #34

I made those changes in my personal project and it worked, but I don't really know how to run the tests in this package.

sbdchd commented 3 years ago

Thanks for the PR! To run the tests, install the dependencies into a .venv; then s/lint should run the linters (you'll also need to install pyright via npm/yarn).

I messed around with this PR a bit and got it passing the linter with:

diff --git a/tests/trout/models.py b/tests/trout/models.py
index 13ffceb..1c595ed 100644
--- a/tests/trout/models.py
+++ b/tests/trout/models.py
@@ -3,7 +3,7 @@ from collections import namedtuple
 from datetime import time, timedelta
 from decimal import Decimal
 from io import StringIO
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from uuid import UUID

 import psycopg2
@@ -40,6 +40,7 @@ from django.views.decorators.http import (
 from psycopg2 import ProgrammingError, sql
 from psycopg2.extensions import parse_dsn
 from psycopg2.extras import execute_values
+from typing_extensions import Literal

 class User(models.Model):
@@ -84,6 +85,13 @@ class Comment(models.Model):

     char = models.CharField()
     char_nullable = models.CharField(null=True)
     char_with_choices = models.CharField[Literal["foo", "bar"]](
         choices=[
             ("foo", "Foo"),
@@ -235,6 +243,10 @@ def process_non_nullable(
     ...

+def check_with_literal(x: Literal["foo", "bar"]) -> None:
+    ...
+
+
 def main() -> None:

     client = Client()
@@ -265,6 +277,11 @@ def main() -> None:
     comment.id = None
     comment.save()

+    reveal_type(comment.char_with_choices)
+
+    check_with_literal(comment.char_with_choices)
+    check_with_literal(comment.ci_char_with_choices)
+
     print(comment.id)

     process_non_nullable(comment.post_fk)
diff --git a/typings/django/contrib/admin/decorators.pyi b/typings/django/contrib/admin/decorators.pyi
index 4ee2c35..a55802e 100644
--- a/typings/django/contrib/admin/decorators.pyi
+++ b/typings/django/contrib/admin/decorators.pyi
@@ -9,7 +9,9 @@ from django.http import HttpRequest, HttpResponse
 _M = TypeVar("_M", bound=Model)

 def action(
-    function: Optional[Callable[[ModelAdmin[_M], HttpRequest, QuerySet[_M]], Optional[HttpResponse]]] = ...,
+    function: Optional[
+        Callable[[ModelAdmin[_M], HttpRequest, QuerySet[_M]], Optional[HttpResponse]]
+    ] = ...,
     *,
     permissions: Optional[Sequence[str]] = ...,
     description: Optional[str] = ...,
diff --git a/typings/django/contrib/postgres/fields/citext.pyi b/typings/django/contrib/postgres/fields/citext.pyi
index c7f50bf..37fc7f0 100644
--- a/typings/django/contrib/postgres/fields/citext.pyi
+++ b/typings/django/contrib/postgres/fields/citext.pyi
@@ -76,7 +76,7 @@ class CICharField(CIText, CharField[_C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self: CICharField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self: CICharField[_C], instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self: CICharField[_C], instance: Any, value: _C) -> None: ...

 class CIEmailField(CIText, EmailField[_C]):
     @overload
@@ -130,7 +130,7 @@ class CIEmailField(CIText, EmailField[_C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self: CIEmailField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self, instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self, instance: Any, value: _C) -> None: ...

 class CITextField(CIText, TextField[_C]):
     @overload
diff --git a/typings/django/db/models/__init__.pyi b/typings/django/db/models/__init__.pyi
index dac446d..de71b40 100644
--- a/typings/django/db/models/__init__.pyi
+++ b/typings/django/db/models/__init__.pyi
@@ -67,9 +67,9 @@ from .fields import FloatField as FloatField
 from .fields import GenericIPAddressField as GenericIPAddressField
 from .fields import IntegerField as IntegerField
 from .fields import IPAddressField as IPAddressField
+from .fields import PositiveBigIntegerField as PositiveBigIntegerField
 from .fields import PositiveIntegerField as PositiveIntegerField
 from .fields import PositiveSmallIntegerField as PositiveSmallIntegerField
-from .fields import PositiveBigIntegerField as PositiveBigIntegerField
 from .fields import SlugField as SlugField
 from .fields import SmallIntegerField as SmallIntegerField
 from .fields import TextField as TextField
diff --git a/typings/django/db/models/fields/__init__.pyi b/typings/django/db/models/fields/__init__.pyi
index 1276468..438d584 100644
--- a/typings/django/db/models/fields/__init__.pyi
+++ b/typings/django/db/models/fields/__init__.pyi
@@ -617,14 +617,15 @@ _C = TypeVar("_C", bound="Optional[str]")
 class CharField(Generic[_C], Field[_C, _C]):
     @overload
     def __init__(
-        self: CharField[str],
+        self: CharField[_C],
+        *,
         verbose_name: Optional[Union[str, bytes]] = ...,
         name: Optional[str] = ...,
         primary_key: bool = ...,
         max_length: Optional[int] = ...,
         unique: bool = ...,
         blank: bool = ...,
-        null: Literal[False] = ...,
+        null: bool = ...,
         db_index: bool = ...,
         default: Any = ...,
         editable: bool = ...,
@@ -633,7 +634,7 @@ class CharField(Generic[_C], Field[_C, _C]):
         unique_for_date: Optional[str] = ...,
         unique_for_month: Optional[str] = ...,
         unique_for_year: Optional[str] = ...,
-        choices: Optional[_FieldChoices] = ...,
+        choices: Optional[_FieldChoices],
         help_text: str = ...,
         db_column: Optional[str] = ...,
         db_tablespace: Optional[str] = ...,
@@ -642,7 +643,8 @@ class CharField(Generic[_C], Field[_C, _C]):
     ) -> None: ...
     @overload
     def __init__(
-        self: CharField[_C],
+        self: CharField[str],
+        *,
         verbose_name: Optional[Union[str, bytes]] = ...,
         name: Optional[str] = ...,
         primary_key: bool = ...,
@@ -668,38 +670,14 @@ class CharField(Generic[_C], Field[_C, _C]):
     @overload
     def __init__(
         self: CharField[Optional[str]],
+        *,
         verbose_name: Optional[Union[str, bytes]] = ...,
         name: Optional[str] = ...,
         primary_key: bool = ...,
         max_length: Optional[int] = ...,
         unique: bool = ...,
         blank: bool = ...,
-        null: Literal[True] = ...,
-        db_index: bool = ...,
-        default: Any = ...,
-        editable: bool = ...,
-        auto_created: bool = ...,
-        serialize: bool = ...,
-        unique_for_date: Optional[str] = ...,
-        unique_for_month: Optional[str] = ...,
-        unique_for_year: Optional[str] = ...,
-        choices: Optional[_FieldChoices] = ...,
-        help_text: str = ...,
-        db_column: Optional[str] = ...,
-        db_tablespace: Optional[str] = ...,
-        validators: Iterable[_ValidatorCallable] = ...,
-        error_messages: Optional[_ErrorMessagesToOverride] = ...,
-    ) -> None: ...
-    @overload
-    def __init__(
-        self: CharField[Optional[_C]],
-        verbose_name: Optional[Union[str, bytes]] = ...,
-        name: Optional[str] = ...,
-        primary_key: bool = ...,
-        max_length: Optional[int] = ...,
-        unique: bool = ...,
-        blank: bool = ...,
-        null: Literal[True] = ...,
+        null: Literal[True],
         db_index: bool = ...,
         default: Any = ...,
         editable: bool = ...,
@@ -716,8 +694,7 @@ class CharField(Generic[_C], Field[_C, _C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self: CharField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __get__(self: CharField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self: CharField[_C], instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self: CharField[_C], instance: Any, value: _C) -> None: ...

 class SlugField(CharField[_C]):
     @overload
@@ -773,7 +750,7 @@ class SlugField(CharField[_C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self: SlugField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self, instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self, instance: Any, value: _C) -> None: ...

 class EmailField(CharField[_C]):
     @overload
@@ -827,7 +804,7 @@ class EmailField(CharField[_C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self, instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self, instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self, instance: Any, value: _C) -> None: ...

 class URLField(CharField[_C]):
     @overload
@@ -881,7 +858,7 @@ class URLField(CharField[_C]):
         error_messages: Optional[_ErrorMessagesToOverride] = ...,
     ) -> None: ...
     def __get__(self, instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
-    def __set__(self, instance: Any, value: _C) -> None: ...  # type: ignore [override]
+    def __set__(self, instance: Any, value: _C) -> None: ...

 class TextField(Generic[_C], Field[str, str]):
     @overload

One thing I noticed, though, is that an existing field like:

char_with_choices_no_type = models.CharField(
    null=True,
    choices=[
        ("foo", "Foo"),
        ("bar", "Bar"),
    ],
)

will result in an error:

tests/trout/models.py:88:33: error: Need type annotation for 'char_with_choices_no_type'  [var-annotated]

whereas before, it would have been typed with __get__ and __set__ that return and accept strings.

Relatedly, does providing choices prevent other values from being stored? If we update the data in the database directly to a value not defined in the choices, and then fetch the record from the DB, we'd end up with a value outside the declared type, so we'd be risking a runtime error.

sbdchd commented 3 years ago

Another thought: we could define a custom field, called EnumField or similar, with an API like:

class FooType(enum.Enum):
    Foo = "Foo"
    Bar = "Bar"

class FooModel(models.Model):
    foo_field = models.EnumField(
        null=True,
        enum=FooType
    )

We'd follow the usual setup outlined in the docs for custom fields: https://docs.djangoproject.com/en/3.2/howto/custom-model-fields/

and then, if necessary, we could add typing-specific definitions for things like __get__ and __set__ by putting those methods in an `if TYPE_CHECKING:` block, so only the type checker sees them and they don't interfere with anything at runtime.

https://mypy.readthedocs.io/en/stable/runtime_troubles.html?highlight=TYPE_CHECKING#typing-type-checking
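To make the `if TYPE_CHECKING` idea concrete, here is a minimal, Django-free sketch. `EnumFieldSketch` and `FooModelSketch` are hypothetical stand-ins for the proposed `EnumField` and a model, not real Django APIs. The point is only that `__get__`/`__set__` declared under `TYPE_CHECKING` are visible to the type checker (so `FooModelSketch().foo_field` is seen as `FooType`) but do not exist at runtime.

```python
import enum
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar

_E = TypeVar("_E", bound=enum.Enum)


class FooType(enum.Enum):
    Foo = "Foo"
    Bar = "Bar"


class EnumFieldSketch(Generic[_E]):
    """Hypothetical stand-in for the proposed EnumField (not a Django API)."""

    def __init__(self, enum_cls: Type[_E]) -> None:
        self.enum_cls = enum_cls

    if TYPE_CHECKING:
        # Only the type checker sees these declarations. At runtime the class
        # has no __get__/__set__, so it is a plain (non-descriptor) object and
        # cannot interfere with the real field's runtime attribute handling.
        def __get__(self, instance: Any, owner: Any) -> _E: ...
        def __set__(self, instance: Any, value: _E) -> None: ...


class FooModelSketch:
    # A type checker sees instance access as FooType; at runtime this is
    # simply the EnumFieldSketch object stored on the class.
    foo_field = EnumFieldSketch(FooType)
```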

Another possible option for an API:

FooType = Literal["Foo", "Bar"]

class FooModel(models.Model):
    foo_field = models.LiteralField(
        null=True,
        options=FooType
    )
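At runtime, a `LiteralField` like this could recover the allowed values from the `Literal` alias with `typing.get_args` (Python 3.8+) and turn them into Django-style `choices`. `literal_to_choices` below is a hypothetical helper, not part of Django or django-types; it only sketches that step under the assumption that the `Literal` contains strings.

```python
from typing import Any, List, Literal, Tuple, get_args

FooType = Literal["Foo", "Bar"]


def literal_to_choices(literal_type: Any) -> List[Tuple[str, str]]:
    """Hypothetical helper: build (value, label) choices from a Literal of strings."""
    values = get_args(literal_type)  # ("Foo", "Bar") for FooType
    if not values or not all(isinstance(v, str) for v in values):
        raise TypeError(f"expected a Literal of strings, got {literal_type!r}")
    # Reuse each value as its own human-readable label.
    return [(v, v) for v in values]


choices = literal_to_choices(FooType)
max_length = max(len(value) for value, _ in choices)
```

Deriving both `choices` and `max_length` from the same alias keeps the type and the field definition from drifting apart, since the values are written only once.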
bellini666 commented 3 years ago

Hey @sbdchd,

To answer your question: the DB could return something that doesn't match the choices, but it is usually up to the application to avoid that. For example, cleaning an object whose field has choices validates that the value is one of them. The same applies when someone modifies a value so that it fails the validators defined on the model.

Regarding your suggestions, I'm already trying to do something like this:

from django.core.exceptions import ValidationError
from django.db import models


class ChoicesField(models.CharField):
    """A CharField that stores a TextChoices value but exposes enum members."""

    description = "Choices"
    default_error_messages = {
        "invalid": "“%(value)s” is not a valid member of %(enum)s.",
    }

    def __init__(self, verbose_name=None, name=None, choices_enum=None, **kwargs):
        self.choices_enum = choices_enum
        # Derive choices and max_length from the enum so they can't drift apart.
        kwargs["choices"] = choices_enum.choices
        kwargs["max_length"] = max(len(c.value) for c in choices_enum)
        super().__init__(verbose_name=verbose_name, name=name, **kwargs)

    def deconstruct(self):
        # Keep choices_enum in migrations so the field can be reconstructed.
        name, path, args, kwargs = super().deconstruct()
        kwargs["choices_enum"] = self.choices_enum
        return name, path, args, kwargs

    def to_python(self, value):
        # Convert raw strings (from forms, the DB, etc.) into enum members.
        if value is None:
            return None

        if isinstance(value, self.choices_enum):
            return value

        try:
            return self.choices_enum(value)
        except ValueError:
            raise ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value, "enum": self.choices_enum},
            )

    def from_db_value(self, value, expression, connection):
        return self.to_python(value)

    def get_prep_value(self, value):
        # CharField.get_prep_value() calls to_python(), so value is an enum
        # member (or None) here; store its underlying string.
        value = super().get_prep_value(value)
        return value.value if value is not None else None

This would work just like a CharField with choices, but it ensures we get and set enum members instead of strings. Not only does that let us work with fewer "magic strings", the typing support should also be better than with the Literal alternative.

I want to try to type it and release it as a PyPI package.

I'm doing some tests, but unfortunately I'm not really sure how to type it properly. I created a .pyi file for it like this:

_C = TypeVar("_C", bound="Optional[TextChoices]")

class ChoicesField(Generic[_C], Field[_C, _C]):
    @overload
    def __new__(
        cls,
        choices_enum: Type[_C],
        verbose_name: Optional[Union[str, bytes]] = ...,
        name: Optional[str] = ...,
        primary_key: bool = ...,
        max_length: Optional[int] = ...,
        unique: bool = ...,
        blank: bool = ...,
        null: Literal[False] = ...,
        db_index: bool = ...,
        default: Any = ...,
        editable: bool = ...,
        auto_created: bool = ...,
        serialize: bool = ...,
        unique_for_date: Optional[str] = ...,
        unique_for_month: Optional[str] = ...,
        unique_for_year: Optional[str] = ...,
        choices: Optional[_FieldChoices] = ...,
        help_text: str = ...,
        db_column: Optional[str] = ...,
        db_tablespace: Optional[str] = ...,
        validators: Iterable[_ValidatorCallable] = ...,
        error_messages: Optional[_ErrorMessagesToOverride] = ...,
    ) -> ChoicesField[_C]: ...
    @overload
    def __new__(
        cls,
        choices_enum: Type[_C],
        verbose_name: Optional[Union[str, bytes]] = ...,
        name: Optional[str] = ...,
        primary_key: bool = ...,
        max_length: Optional[int] = ...,
        unique: bool = ...,
        blank: bool = ...,
        null: Literal[True] = ...,
        db_index: bool = ...,
        default: Any = ...,
        editable: bool = ...,
        auto_created: bool = ...,
        serialize: bool = ...,
        unique_for_date: Optional[str] = ...,
        unique_for_month: Optional[str] = ...,
        unique_for_year: Optional[str] = ...,
        choices: Optional[_FieldChoices] = ...,
        help_text: str = ...,
        db_column: Optional[str] = ...,
        db_tablespace: Optional[str] = ...,
        validators: Iterable[_ValidatorCallable] = ...,
        error_messages: Optional[_ErrorMessagesToOverride] = ...,
    ) -> ChoicesField[Optional[_C]]: ...
    @overload
    def __get__(self: ChoicesField[_C], instance: Any, owner: Any) -> _C: ...  # type: ignore [override]
    @overload
    def __get__(self: ChoicesField[Optional[_C]], instance: Any, owner: Any) -> Optional[_C]: ...  # type: ignore [override]
    @overload
    def __set__(self, instance: ChoicesField[_C], value: _C) -> None: ...  # type: ignore [override]
    @overload
    def __set__(self, instance: ChoicesField[Optional[_C]], value: Optional[_C]) -> None: ...  # type: ignore [override]

I then created this test model:

class Foo(models.Model):
    class SomeEnum(models.TextChoices):
        FOO = "foo", "Foo Desc"
        BAR = "bar", "bar Desc"

    c = ChoicesField(
        choices_enum=SomeEnum,
        default=SomeEnum.FOO,
    )
    c_nullable = ChoicesField(
        choices_enum=SomeEnum,
        null=True,
        default=SomeEnum.FOO,
    )

Some problems I'm having:

1) pyright says that Foo().c has a type of SomeEnum | None
2) It says it could not match the signature for c_nullable
3) When assigning Foo().c = <something>, it doesn't really validate that I need an instance of SomeEnum

Although I've been developing in Python since 2010, I only just started working with types, so things are still a little messy for me =P

Also, I think we need to type TextChoices here, because if I try to do some_var: Foo.SomeEnum = Foo.SomeEnum.FOO it says that they are incompatible. TextChoices is a subclass of Enum, so maybe it should be typed like this? https://github.com/python/typeshed/blob/master/stdlib/enum.pyi (accounting for their differences, of course)

P.S. Even though I want to implement this field, I still think that optionally providing a Literal of the allowed options, when choices is given as a list of tuples, is useful for typing. So I think this PR still makes sense (if you agree).
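For reference, here is a minimal, Django-free sketch of one way to get the behavior expected in points 1 to 3: `__init__` overloads that use a `self` annotation to pick the value type from the `null` argument (similar to how the CharField stubs in the diff above do it), plus a typed descriptor. `ChoicesFieldSketch` is hypothetical and is not the django-choices-field implementation; it stores values in the instance `__dict__` only to keep the example runnable.

```python
import enum
from typing import Any, Generic, Literal, Optional, Type, TypeVar, overload

_E = TypeVar("_E", bound=enum.Enum)
_T = TypeVar("_T")  # value type seen on instances: _E, or Optional[_E] if null=True


class ChoicesFieldSketch(Generic[_T]):
    @overload
    def __init__(
        self: "ChoicesFieldSketch[_E]", choices_enum: Type[_E], *, null: Literal[False] = ...
    ) -> None: ...
    @overload
    def __init__(
        self: "ChoicesFieldSketch[Optional[_E]]", choices_enum: Type[_E], *, null: Literal[True]
    ) -> None: ...
    def __init__(self, choices_enum: Type[enum.Enum], *, null: bool = False) -> None:
        self.choices_enum = choices_enum
        self.null = null
        self.attname = ""

    def __set_name__(self, owner: type, name: str) -> None:
        self.attname = "_" + name  # per-instance storage key

    @overload
    def __get__(self, instance: None, owner: Any) -> "ChoicesFieldSketch[_T]": ...
    @overload
    def __get__(self, instance: object, owner: Any) -> _T: ...
    def __get__(self, instance: Optional[object], owner: Any) -> Any:
        if instance is None:
            return self  # class access returns the field itself
        try:
            return instance.__dict__[self.attname]
        except KeyError:
            raise AttributeError(self.attname[1:]) from None

    def __set__(self, instance: object, value: _T) -> None:
        # Runtime check mirroring what the type checker enforces statically.
        if value is None:
            if not self.null:
                raise ValueError("this field is not nullable")
        elif not isinstance(value, self.choices_enum):
            raise TypeError(f"expected {self.choices_enum.__name__}, got {value!r}")
        instance.__dict__[self.attname] = value


class SomeEnum(enum.Enum):
    FOO = "foo"
    BAR = "bar"


class Foo:
    c = ChoicesFieldSketch(SomeEnum)  # checker: ChoicesFieldSketch[SomeEnum]
    c_nullable = ChoicesFieldSketch(SomeEnum, null=True)  # ...[Optional[SomeEnum]]
```

Because `null: Literal[True]` has no default in the second overload, a call without `null` can only match the first one, so `Foo().c` is `SomeEnum` and only `c_nullable` admits `None`.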

bellini666 commented 3 years ago

Hey @sbdchd, I just created the package here: https://github.com/bellini666/django-choices-field

I'm messing with it to find a way to properly type it.

bellini666 commented 3 years ago

OK, I just noticed what I was doing wrong: I had set pyright to default to Python 3.7, but typing.Literal only exists since 3.8.

Now the typing on https://github.com/bellini666/django-choices-field is working fine! :)
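If a package also has to support Python 3.7, one common pattern (assuming `typing_extensions` is installed there, as in the test diff above) is a version-gated import. Both mypy and pyright understand `sys.version_info` checks, so they follow the branch that matches the configured Python version. `FooChoice` and `describe` are illustrative names only.

```python
import sys

if sys.version_info >= (3, 8):
    from typing import Literal
else:  # Python 3.7: Literal is only available from typing_extensions
    from typing_extensions import Literal

FooChoice = Literal["foo", "bar"]  # illustrative alias


def describe(choice: FooChoice) -> str:
    # A checker rejects describe("baz"); at runtime nothing is enforced.
    return f"choice={choice}"
```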

Regarding this PR, like I said, I still think it is a nice addition. What do you think? Do you want me to apply your diff, or do you want to do it yourself?

sbdchd commented 3 years ago

Nice work with the package, looks good!

I'm hesitant to merge the PR because the explicit type annotation isn't really checked, so it ends up being a sort of cast. With the current setup, users don't have to change their existing field definitions to start using the types.

I think the custom field approach is probably better, because then users can assert that the value is in their expected union of values (or skip that part if they have a unique constraint defined elsewhere); essentially, it gives users more flexibility.
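The "assert the value is in the expected union" step could look roughly like this sketch, which validates a plain `str` (for example, a CharField value read back from the DB) against a `Literal` alias before narrowing it. `FooChoice` and `as_foo_choice` are hypothetical names. `cast` is used because not every type checker narrows `str` to a `Literal` from an `in` check; unlike the unchecked type parameter discussed above, the cast here only happens after a runtime check.

```python
from typing import Literal, Optional, cast, get_args

FooChoice = Literal["foo", "bar"]
_FOO_VALUES = get_args(FooChoice)  # ("foo", "bar")


def as_foo_choice(raw: Optional[str]) -> Optional[FooChoice]:
    """Return raw typed as FooChoice, or raise if it is outside the Literal."""
    if raw is None:
        return None
    if raw not in _FOO_VALUES:
        # e.g. a row edited directly in the DB with a value not in the choices
        raise ValueError(f"{raw!r} is not one of {_FOO_VALUES}")
    return cast(FooChoice, raw)  # safe: membership was checked just above
```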

bellini666 commented 3 years ago

@sbdchd np, feel free to reject this then! :)

Btw, I just updated the package to include both TextChoicesField and IntegerChoicesField, in case you are interested!