CarliJoy / intersection_examples

Python Typing Intersection examples
MIT License

Example use case for intersection #40

Open Badg opened 7 months ago

Badg commented 7 months ago

While reading through #29, I saw this comment:

I think the most important thing that we're missing are motivating examples.

So I figured I'd chime in with one that I run into very, very frequently! Hopefully this might be of some use. As some background: I make heavy use of decorators, both for "plain" callables as well as classes. In fact, a very common pattern I use is to combine both of them: decorators on methods to indicate some special behavior, and then a decorator on the class to collect those and do some kind of transform on them. I find this to be a pretty clean library API, and I use it all the time, for a wide variety of things. Critical to note here is that the decorators always return the original object, with a few attributes added to it. This is something I've found much more ergonomic than wrapping, but that's a whole separate discussion that I won't go into. The important thing to know is that adding class attributes to classes, or attributes to functions, is extremely common for me.
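
A minimal sketch of the pattern described above, with hypothetical names (`action`, `collect_actions`, `Machine`) that are not part of the library discussed in this issue. It shows a method decorator that returns the original function with an attribute added, and a class decorator that collects the marked methods. Without intersections, both decorators can only be annotated as returning their input type, so the added attributes are invisible to a type checker.

```python
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])
C = TypeVar("C", bound=type)


def action(func: F) -> F:
    """Mark a function by adding an attribute; return the same object."""
    # The return type is just F, so the checker never learns about
    # the attribute added here.
    func._is_action = True  # type: ignore[attr-defined]
    return func


def collect_actions(cls: C) -> C:
    """Gather the names of all marked methods onto the class itself."""
    cls._actions = tuple(  # type: ignore[attr-defined]
        name for name, member in vars(cls).items()
        if getattr(member, "_is_action", False)
    )
    return cls


@collect_actions
class Machine:
    @action
    def start(self) -> str:
        return "started"

    @action
    def stop(self) -> str:
        return "stopped"

    def helper(self) -> None:
        pass
```

At runtime `Machine._actions` holds the names of the two marked methods, but a type checker would report the attribute as unknown because `collect_actions` is typed as returning `C` unchanged.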

So without further ado, here is a (simplified) example of the kind of thing I currently use fairly frequently, but cannot correctly type-hint due to missing intersections:

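import contextlib
from collections.abc import Callable, Generator
from inspect import get_annotations
from typing import ClassVar, Literal, Protocol, cast

# NOTE: Intersection is the proposed construct and does not exist in typing yet.

class _Stator(Protocol):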
    """This protocol is used to define the actual individual states for
    the state machine.
    """
    _is_stator: ClassVar[Literal[True]]

class _StatorAction(Protocol):
    _stator_cls_required: type[_Stator]

    # Ideally this would also include modifications to the function signature to
    # note that "stator" is required as the first positional arg

def stator_action[C: Callable](stator_action: C) -> Intersection[_StatorAction, C]:
    stator_action._stator_cls_required = get_annotations(stator_action)['stator']
    return stator_action

class StateMachine[SM](Protocol):
    """This protocol describes the return of the @state_machine
    decorator, and contains the implementations for the functions that
    are added into the decorated class.
    """
    _stator_actions: ClassVar[tuple[_StatorAction, ...]]

    @contextlib.contextmanager
    def from_checkpoint(self, stator: _Stator) -> Generator[SM, None, None]:
        """Creates a StatorMachina instance and loads the passed stator
        into it, setting it as the currently active state. Then, upon
        context exit, verifies that the state machine was exited either
        through a paused or final state, or by an error.
        """
        do_some_stuff()
        yield cast(MachinaProtocol, stator_machina)
        do_other_stuff()

def stator[T: type](stator_cls: T) -> Intersection[type[_Stator], T]:
    """Use this decorator to declare a particular node on the state
    diagram, including all of the internal state that needs to be stored
    at that point.

    It may be useful to pair this with a dataclass (and by "may", I mean
    "almost certainly IS"), but this isn't an absolute requirement.
    """
    stator_cls._is_stator = True
    return stator_cls

class Transition[S: _Stator]:
    """Return a state node wrapped in a Transition to denote a state
    transition. Can also be used to transition the state machine into
    a paused state.
    """

    def __init__(self, to_state: S):
        self._to_state = to_state

def state_machine[SM](cls: SM) -> Intersection[SM, StateMachine]:
    """This decorator marks a class as a state machine class, gathers
    up all of the stator actions, and constructs the internal machinery
    used by from_checkpoint to enforce correct state transitions.
    """
    cls._stator_actions = gather_stator_actions(cls)
    cls.from_checkpoint = _build_from_checkpoint(cls)
    return cls

This is then coupled with a bunch of other code that operates on _Stator, _StatorAction, etc. Code using this looks something like this (note that this is the unsimplified version; in reality most of those decorators are second-order decorators).

@dataclass
@stator(checkpoint=True)
class DmktSignupFormSubmitted:
    """The first state in both the initiation and reinitiation flows:
    the dmkt signup form was submitted.
    """
    email_address: str
    topics: list[str]
    dmkt_frequency: DirectMarketingFrequency

@dataclass
@stator
class DmktSignupTokenIssued(DmktSignupFormSubmitted):
    """The second state in both the initiation and reinitiation flows:
    we issued the relevant token.
    """
    token_id: UserVisibleUUID
    consent_copy_info: ConsentCopyInfo

@state_machine
class InitiationFromDmktSignup(StateMachine):
    """When a user goes through the direct marketing signup flow (ie,
    the mailing list signup), and no existing someone is found at the
    given email address, this state machine is responsible for guiding
    the person through the initiation flow, from initial form submission
    all the way through final double-opt-in confirmation, including the
    break in the middle for the email verification link.

    Note that, due to race conditions, just because a someone doesn't
    exist at the time of the dmkt signup, doesn't mean they don't
    concurrently complete a registration. So this flow can still result
    in a noop / error condition.
    """

    def __init__(self, *args, db_conn, **kwargs):
        """SUPER IMPORTANT: the db_conn needs to be within a transaction
        already!
        """
        super().__init__(*args, **kwargs)
        self._db_conn = db_conn

    @stator_action(DmktSignupFormSubmitted)
    async def issue_initiation_token(
            self, stator: DmktSignupFormSubmitted
            ) -> Transition[DmktSignupTokenIssued]:
        """The first step in the initiation flow: issuing a signup token
        that stores the submitted info.
        """
        consent_copy_info = await get_copy_info_for_flavor(
            UserfacingCopyFlavor.DMKT_MAILING_LIST_DOUBLEOPTIN_EMAIL)
        token_payload = DmktSignupResumeEncapsulation(
            email_address=stator.email_address,
            topics=stator.topics,
            dmkt_frequency=stator.dmkt_frequency,
            userfacing_copy_version_id=
                consent_copy_info.userfacing_copy_version_id)
        token_id = await issue_anonymous_token(
            {AttestedAction.DMKT_SIGNUP_RESUME: token_payload})
        return Transition(DmktSignupTokenIssued(
            email_address=stator.email_address,
            topics=stator.topics,
            dmkt_frequency=stator.dmkt_frequency,
            token_id=token_id,
            consent_copy_info=consent_copy_info))

    @resume_at_stator(DmktSignupEmailClicked)
    @stator_action(DmktSignupTokenIssued)
    async def send_initiation_email(self, stator) -> Transition[Paused]:
        """Once the token has been issued, we can send the initiation
        email containing the consent copy and a link to the token.
        """

# Note: I've removed the web framework stuff here so it's more concise 
async def form_submit_route_handler():
    maybe_someone_id = await get_someone_id_from_email(
        form_payload.email_address)

    form_submitted = DmktSignupFormSubmitted(
        email_address=form_payload.email_address,
        topics=form_payload.subscription_topics,
        dmkt_frequency=form_payload.contact_frequency)

    postgres_pool = Singleton.get(TaetimePostgres)
    async with postgres_pool.acquire_wrapped_conn() as conn:
        if maybe_someone_id is None:
            initiation = InitiationFromDmktSignup(db_conn=conn)

            with initiation.from_checkpoint(form_submitted) as initiation_flow:
                async with conn.transaction():
                    await initiation_flow.issue_initiation_token()
                    await initiation_flow.send_initiation_email()

There are a couple things I want to point out here:

  1. Note that there are 4 times I would use an intersection in the library code, even for that relatively short snippet. I believe there are even more in the test code for it as well (not 100% sure; I'm much laxer on typing with test code)
  2. For library consumers, it's impossible for me to correctly type-hint this code. Either they're missing the state machine actions -- which is terrible UX for someone writing code that interacts with the state machine -- or they're missing from_checkpoint, which is added by the @state_machine decorator. Explicitly subclassing StateMachine (as I did above) is a workaround for the from_checkpoint, but for more sophisticated decorators (or ones that are meant for functions and not for classes), this isn't an option. Furthermore, in reality you might want the @stator decorator to add some methods on to the class -- just like @dataclass does. At that point, you're really SOL.
  3. For library implementers, this also presents a problem. The actual implementations of the above decorators want to make heavy use of the _Stator, StateMachine, etc. protocols internally, but they can't do that without sacrificing the API presented to library consumers. As a result, the types applied to implementation functions are overly broad (lots of unqualified types) and still fail type checking, because the type checker doesn't know about the intersections. So I either have to wrap the implementation in two calls to cast (first to apply the protocol, then to revert to the original type), or I have to add a bunch of type: ignore comments.
  4. It seems to me like a lot of untyped code in the stdlib could probably be given types with the above protocol + intersection pattern; for example, tons of stuff in functools.
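
The two-cast workaround mentioned in point 3 might look roughly like the sketch below. The names `_HasRegistry`, `register_public_names`, and `Config` are hypothetical and chosen only for illustration; they are not taken from the library above.

```python
from typing import ClassVar, Protocol, TypeVar, cast

T = TypeVar("T", bound=type)


class _HasRegistry(Protocol):
    _registry: ClassVar[tuple[str, ...]]


def register_public_names(cls: T) -> T:
    # First cast: view the class through the protocol so that the
    # attribute assignment type-checks inside the implementation.
    as_protocol = cast(type[_HasRegistry], cls)
    as_protocol._registry = tuple(
        name for name in vars(cls) if not name.startswith("_"))
    # Second cast: revert to the caller's type. The protocol part is
    # lost here, because T & type[_HasRegistry] cannot be spelled.
    return cast(T, as_protocol)


@register_public_names
class Config:
    host = "localhost"
    port = 8080
```

Consumers of `Config` keep `host` and `port` in the eyes of the type checker, but `Config._registry` is unknown to it, which is the trade-off described in point 2.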

Badg commented 7 months ago

PS: hopefully that's actually helpful and not just a wall of text!

vergenzt commented 7 months ago

I've got a similar use case to yours -- a decorator that I cannot currently type annotate adequately b/c it adds attributes to a passed-in class. (I need to say that for the user's type T that they pass in, I return something that's T & MyProtocol.)

Badg commented 7 months ago

PS up top for visibility: I'd be happy to clean these up for use in a PEP. And by that I mean, distill them into "minimum viable examples" without all of the application-specific confusery. If desired, @ me in this issue somewhere once there's a rough draft of the PEP or something; I just need more than zero context to go off of.

Just ran into another use case -- generic protocols and typeguards. Though tbh, I'm not 100% sure there isn't an alternative way of doing this right now; I found, for example, this SO answer which is similar.

Anyways, this is another, completely unrelated set of code to what I posted above -- this time for a library to make it more convenient to define and use test vectors, especially for integration tests -- but I'm again using the same decorators + protocols pattern. Admittedly this is a little bit verbose, but... well, that's a whole separate discussion, and it's late here in CET, almost midnight. So this is the code as it exists right now:

class _TevecType(Protocol):
    _tevec_fields: ClassVar[tuple[str, ...]]
    _tevec_applicators: ClassVar[dict[str, str]]
    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
    __call__: Callable

def _is_tevec_type(obj) -> TypeGuard[type[_TevecType]]:
    return (
        hasattr(obj, '_tevec_fields')
        and hasattr(obj, '_tevec_applicators')
        and callable(obj)
        and hasattr(obj, '__dataclass_fields__'))

The problem comes every time I use the type guard: the protocol is intended as a mixin, but after calling the type guard, pyright will narrow the type to eliminate all of the existing attributes. So, for example, in the test code:

    def test_field_proxies_on_decorated_class(self):
        @tevec_type
        class Foo:
            foo: int
            bar: Optional[int] = None

        assert _is_tevec_type(Foo)
        assert fields(Foo)
        # Using set here to remove ordering anomalies
        assert set(Foo._tevec_fields) == {'foo', 'bar'}
        assert not Foo._tevec_applicators

        assert hasattr(Foo, 'foo')
        assert hasattr(Foo, 'bar')
        # NOTE: both of these fail type checking because we need an
        # intersection type for the _is_tevec_type type guard
        assert isinstance(Foo.foo, _TevecFieldProxy)
        assert isinstance(Foo.bar, _TevecFieldProxy)

What I want to do instead is something along the lines of:

def _is_tevec_type[T: type](obj: T) -> TypeGuard[Intersection[T, type[_TevecType]]]:
    ...

(except probably I'd need to overload that because anything where obj isn't a type is just going to return False, but I clearly haven't thought that far).

That being said... the "spelling" here is a little bizarre, and I'm having a hard time wrapping my head around it. I think if the type checker has only narrowed the type of obj to type, then it would be a noop, but if there's a more specific type, it would... un-narrow it to include the protocol? That seems a bit odd, but I think the important thing I'm trying to highlight is that, if one of the primary use cases of protocols is to be used to type mixins, but typeguards break the mixin-ness without an intersection, then that would definitely be something to consider.

Edit: almost forgot, while poking around for this, I found these two issues on pyright, which are related:

As this comment mentions, without an intersection type, the problem is basically unsolvable; you can only choose one issue or the other.

gentlegiantJGC commented 6 months ago

I agree that it would be useful to add type hints when patching attributes onto an existing object. Particularly methods. Here is my use case which is quite similar to your example.

from typing import Callable, TypeVar, Protocol, Union as Intersection, cast
import inspect

T = TypeVar("T")
F = TypeVar("F")

class MyAttrContainer(Protocol[T]):
    my_attr: T

def add_my_attr(attr: T) -> Callable[[F], Intersection[F, MyAttrContainer[T]]]:
    def wrap(func: F) -> Intersection[F, MyAttrContainer[T]]:
        func_ = cast(Intersection[F, MyAttrContainer[T]], func)
        func_.my_attr = attr
        return func_

    return wrap

class Test:
    @add_my_attr("hello world")
    def test(self) -> str:
        return "test"

t = Test()
print(t.test.my_attr)
print(t.test())
print(inspect.ismethod(t.test))

Edit: For completeness, the above example can be achieved using the current typing system, but you can't stack decorators; otherwise the original would get lost. Intersection would allow stacking multiple attribute-adding decorators. The current approach also requires a different implementation for functions and methods, which Intersection wouldn't.

A typed example for adding an attribute to a method or function can be found here. https://github.com/Amulet-Team/Amulet-Core/commit/aef557d7f1180f5fc1d43c0c5121aa273cfc598d