adriangb / di

Pythonic dependency injection
https://www.adriangb.com/di/
MIT License

bug: pep 563 (class __init__ annotations parsing) #105

Closed: maxzhenzhera closed this issue 1 year ago

maxzhenzhera commented 1 year ago

Example

Modified version of di/docs_src/autowiring.py:

from __future__ import annotations

import asyncio
from dataclasses import dataclass

from di import Container
from di.dependent import Dependent
from di.executors import AsyncExecutor

@dataclass
class Config:
    host: str = "localhost"

class DBConn:
    def __init__(self, config: Config) -> None:
        self.host = config.host

async def endpoint(conn: DBConn) -> None:
    assert isinstance(conn, DBConn)

async def framework():
    container = Container()
    solved = container.solve(Dependent(endpoint, scope="request"), scopes=["request"])
    async with container.enter_scope("request") as state:
        await solved.execute_async(executor=AsyncExecutor(), state=state)

if __name__ == "__main__":
    asyncio.run(framework())

Traceback:

Traceback (most recent call last):
  File "/home/maxzhenzhera/repos/di/docs_src/autowiring.py", line 33, in <module>
    asyncio.run(framework())
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/maxzhenzhera/repos/di/docs_src/autowiring.py", line 27, in framework
    solved = container.solve(Dependent(endpoint, scope="request"), scopes=["request"])
  File "/home/maxzhenzhera/repos/di/di/_container.py", line 649, in solve
    return solve(dependency, scopes, self._bind_hooks, scope_resolver)
  File "/home/maxzhenzhera/repos/di/di/_container.py", line 476, in solve
    root_task = build_task(
  File "/home/maxzhenzhera/repos/di/di/_container.py", line 319, in build_task
    child_task = build_task(
  File "/home/maxzhenzhera/repos/di/di/_container.py", line 307, in build_task
    params = get_params(dependency, binds, path)
  File "/home/maxzhenzhera/repos/di/di/_container.py", line 252, in get_params
    raise WiringError(
di.exceptions.WiringError: The parameter config to <class '__main__.DBConn'> has no dependency marker, no type annotation and no default value. This will produce a TypeError when this function is called. You must either provide a dependency marker, a type annotation or a default value.
Path: Dependent(call=<function endpoint at 0x7f797e183e20>, use_cache=True) -> Dependent(call=<class '__main__.DBConn'>, use_cache=True)

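Here we can note that the error claims config has no type annotation, although DBConn.__init__ annotates it as Config.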

Explanation

The problem is in how the real annotations are parsed.

The parameter inspection at
https://github.com/adriangb/di/blob/f8b0f4b38e6f43c4b5365bac1c663b64a60afefd/di/_utils/inspect.py#L68-L83
obtains the annotations with annotations = get_annotations(call), which is implemented at
https://github.com/adriangb/di/blob/f8b0f4b38e6f43c4b5365bac1c663b64a60afefd/di/_utils/inspect.py#L48-L65

Why does it occur only with from __future__ import annotations?

Without this future import, the annotations are left as they are (real types), so get_parameters(), which relies only on inspect.signature, is enough:

if inspect.isclass(call) and (call.__new__ is not object.__new__):  # type: ignore[comparison-overlap]
    # classes overriding __new__, including some generic metaclasses, result in __new__ getting read
    # instead of __init__
    params = inspect.signature(call.__init__).parameters  # type: ignore[misc] # accessing __init__ directly
    params = dict(params)
    params.pop(next(iter(params.keys())))  # first parameter to __init__ is self
else:
    params = inspect.signature(call).parameters
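A minimal stdlib-only sketch of the difference (the class names are hypothetical and only mirror the example above). Quoting an annotation by hand stores the same string that PEP 563 stores for every annotation, and inspect.signature does not evaluate such strings by default:

```python
import inspect
from dataclasses import dataclass


@dataclass
class Config:
    host: str = "localhost"


class RealConn:
    # Evaluated at definition time: the annotation is the Config class itself.
    def __init__(self, config: Config) -> None:
        self.host = config.host


class StringConn:
    # Quoted by hand: the annotation is the string "Config",
    # which is what "from __future__ import annotations" would produce.
    def __init__(self, config: "Config") -> None:
        self.host = config.host


# For a class, inspect.signature() reports the parameters of __init__ without self.
real_ann = inspect.signature(RealConn).parameters["config"].annotation
string_ann = inspect.signature(StringConn).parameters["config"].annotation

print(repr(string_ann))  # 'Config'
```

With real types, the annotation read from the signature is already usable for wiring; with stringized annotations it is only a name that still has to be resolved.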

What's wrong?

In my case, I do use the future import, so the annotations are stringized. Therefore, get_parameters() (which relies only on inspect.signature) is NOT enough, and we dive into get_annotations():

if not (
    inspect.isclass(call) or inspect.isfunction(call) or inspect.ismethod(call)
) and hasattr(call, "__call__"):
    # callable class
    types_from = call.__call__  # type: ignore[misc,operator] # accessing __init__ directly
else:
    # method
    types_from = call

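There is no handling of __init__: when call is a class, types_from is the class itself, so get_type_hints() reads the class-level annotations instead of the parameters of __init__.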

We can check it with:

from di._utils.inspect import get_annotations

print(get_annotations(DBConn))  # {}

So no dependency is built for Config: its annotation stays in stringized form instead of being replaced with the real type, as expected.
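A minimal stdlib-only sketch of why reading the hints from __init__ helps (the classes are hypothetical stand-ins, with a hand-quoted annotation playing the role of a PEP 563 string). On a class, typing.get_type_hints collects class-level annotations along the MRO; on the __init__ function, it evaluates the strings in __init__.__globals__:

```python
from typing import get_type_hints


class Config:
    pass


class DBConn:
    # Hand-quoted annotation, equivalent to what PEP 563 stores.
    def __init__(self, config: "Config") -> None:
        self.config = config


# Class-level annotations only: DBConn declares none, so this is empty.
class_hints = get_type_hints(DBConn)

# Function annotations: "Config" is resolved through DBConn.__init__.__globals__,
# and the None return annotation comes back as NoneType.
init_hints = get_type_hints(DBConn.__init__)

print(class_hints)  # {}
print(init_hints)
```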

Workaround

It works as expected if I take the type hints from __init__ when call is a class:

import inspect
from typing import Annotated, Any, Callable, Dict, Union, get_args, get_origin, get_type_hints


def get_annotations(call: Callable[..., Any]) -> Dict[str, Any]:
    types_from: Callable[..., Any]
    if not (
        inspect.isclass(call) or inspect.isfunction(call) or inspect.ismethod(call)
    ) and hasattr(call, "__call__"):
        # callable class instance
        types_from = call.__call__  # type: ignore[misc,operator] # accessing __call__ directly
    else:
        # function, method or class
        types_from = call

    #############################################
    # handle init: for a class, read the hints from __init__, so that
    # stringized annotations are resolved against __init__.__globals__
    if inspect.isclass(call):
        types_from = call.__init__  # type: ignore[misc] # accessing __init__ directly
    #############################################

    hints = get_type_hints(types_from, include_extras=True)
    # for no apparent reason, Annotated[Optional[T]] comes back as Optional[Annotated[Optional[T]]]
    # so remove the outer Optional if this is the case
    for param_name, hint in hints.items():
        args = get_args(hint)
        if get_origin(hint) is Union and get_origin(next(iter(args))) is Annotated:
            hints[param_name] = next(iter(args))
    return hints

The previous example now works correctly:

from di._utils.inspect import get_annotations

print(get_annotations(DBConn))  # {'config': <class '__main__.Config'>, 'return': <class 'NoneType'>}

Note: the breaking commit is https://github.com/adriangb/di/commit/722ede44bea24b6c2aa0df6d150053e7772fe1b9

adriangb commented 1 year ago

If you have a fix, please feel free to submit a PR.