wjakob / nanobind

nanobind: tiny and efficient C++/Python bindings
BSD 3-Clause "New" or "Revised" License

Sorted imports #462

Closed. laggykiller closed this 3 months ago

laggykiller commented 3 months ago

See https://github.com/wjakob/nanobind/issues/420#issuecomment-1986957935

wjakob commented 3 months ago

Regarding the PackagesDict data structure: can this not be done more simply at the very end, when the import list is being generated? Instead of traversing sorted(), you could use sorted(.., key=) with a key function that looks at the module name and distinguishes between typing + collections, anything starting with self.module.__name__.split('.')[0], and everything else.
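A minimal sketch of such a key function, for illustration only. The names `import_sort_key` and `own_package` are hypothetical and not part of nanobind; the three groups follow the description above, with the group order (typing/collections, everything else, own package) being an assumption.

```python
from typing import Tuple


def import_sort_key(module: str, own_package: str) -> Tuple[int, str]:
    """Return (group, name) so that sorted() orders by group first, then by name."""
    top = module.split(".")[0]
    if module.startswith(".") or top == own_package:
        group = 2  # the package being built (assumed to sort last)
    elif top in ("typing", "collections"):
        group = 0  # the stdlib modules named in the comment above
    else:
        group = 1  # everything else
    return (group, module)


# Hypothetical module names; "my_ext" stands in for self.module.__name__.split('.')[0]
modules = ["numpy.typing", "my_ext.sub", "typing", "collections.abc"]
ordered = sorted(modules, key=lambda m: import_sort_key(m, "my_ext"))
```

Returning a tuple rather than a concatenated string keeps the group and the name as separate sort fields, so the ordering does not depend on how the group number is formatted.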

laggykiller commented 3 months ago

> Instead of traversing sorted(), you could use sorted(.., key=)

This would be the best solution, but I want to add an extra line break between imports from different parties, e.g.:

    from typing import Sequence

    from numpy.typing import ArrayLike

instead of:

    from typing import Sequence
    from numpy.typing import ArrayLike

If I just loop through sorted(self.imports, key=custom_key_func), it is not possible to tell where the stdlib imports end and the 3rd-party imports start, so I cannot add the line break.

However, I did simplify things a bit in the latest commit; please check.

laggykiller commented 3 months ago

Alternative solution:

    def check_party(self, module: str) -> Literal[0, 1, 2]:
        """
        Check source of module
        0 = From stdlib
        1 = From 3rd party package
        2 = From the package being built
        """
        # Relative imports and anything under the top-level package being built
        if module.startswith(".") or module.split(".")[0] == self.module.__name__.split(".")[0]:
            return 2

        try:
            spec = importlib.util.find_spec(module)
        except (ModuleNotFoundError, ValueError):
            # Missing parent package, or an already-imported module without a usable __spec__
            return 1

        if spec:
            if spec.origin and "site-packages" in spec.origin:
                return 1
            else:
                return 0
        else:
            return 1

    def get(self) -> str:
        """Generate the final stub output"""
        s = ""
        last_party = None

        for module in sorted(self.imports, key=lambda i: (self.check_party(i), i)):
            imports = self.imports[module]
            items: List[str] = []
            party = self.check_party(module)

            if party != last_party:
                if last_party is not None:
                    s += "\n"
                last_party = party

            for (k, v1), v2 in imports.items():
                if k is None:
                    if v1 and v1 != module:
                        s += f"import {module} as {v1}\n"
                    else:
                        s += f"import {module}\n"
                else:
                    if k != v2 or v1:
                        items.append(f"{k} as {v2}")
                    else:
                        items.append(k)

            items = sorted(items)
            if items:
                items_v0 = ", ".join(items)
                items_v0 = f"from {module} import {items_v0}\n"
                items_v1 = "(\n    " + ",\n    ".join(items) + "\n)"
                items_v1 = f"from {module} import {items_v1}\n"
                s += items_v0 if len(items_v0) <= 70 else items_v1

        s += "\n\n"
        s += self.put_abstract_enum_class()

        # Append the main generated stub
        s += self.output

        return s.rstrip() + "\n"

I am not sure if this is cleaner, though.
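For comparison, the blank line between parties could also be produced with itertools.groupby instead of tracking last_party by hand. This is only a sketch under stated assumptions: `render_grouped` and `party_of` are hypothetical names, `party_of` stands in for check_party, and only plain `import x` lines are emitted (no `from ... import ...` handling).

```python
from itertools import groupby
from typing import Callable, Iterable


def render_grouped(modules: Iterable[str], party_of: Callable[[str], int]) -> str:
    """Emit one 'import' line per module, with a blank line between parties."""
    # Sort by (party, name) so that modules of the same party are adjacent,
    # which groupby requires in order to form one group per party.
    keyed = sorted(modules, key=lambda m: (party_of(m), m))
    blocks = []
    for _, group in groupby(keyed, key=party_of):
        blocks.append("".join(f"import {m}\n" for m in group))
    # Joining the blocks with "\n" inserts exactly one blank line between them.
    return "\n".join(blocks)


def toy_party(module: str) -> int:
    """Toy stand-in for check_party: treat 'os' and 'typing' as stdlib."""
    return 0 if module in ("os", "typing") else 1


text = render_grouped(["numpy", "typing", "os"], toy_party)
```

The trade-off is that party_of is called once for sorting and once for grouping per module, the same double lookup as in the version above.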

wjakob commented 3 months ago

I like the version from your last comment better. Could you update the PR to reflect it?

laggykiller commented 3 months ago

PR updated, please check!

wjakob commented 3 months ago

Great, thank you!