laggykiller closed this 3 months ago
Regarding the PackagesDict data structure: can this not be done in an easier way, all the way at the end when the import list is being generated? Instead of traversing sorted(), you could use sorted(.., key=), where the key function looks at the module name and distinguishes between typing + collections, anything starting with self.module.__name__.split('.')[0], and everything else.
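A minimal sketch of the key function suggested above, assuming three groups in the order typing/collections, everything else, then the package itself. `import_sort_key` and `PACKAGE_ROOT` are hypothetical names, and `"mypkg"` stands in for `self.module.__name__.split('.')[0]`:

```python
from typing import Tuple

# Hypothetical: top-level name of the package whose stubs are being generated.
# In the stub generator this would come from self.module.__name__.split('.')[0].
PACKAGE_ROOT = "mypkg"


def import_sort_key(module: str) -> Tuple[int, str]:
    """Group 0: typing/collections, group 1: everything else, group 2: the package itself."""
    root = module.split(".")[0]
    if root in ("typing", "collections"):
        group = 0
    elif module.startswith(".") or root == PACKAGE_ROOT:
        group = 2
    else:
        group = 1
    # Sort by group first, then alphabetically by module name within a group
    return (group, module)


modules = ["numpy.typing", "mypkg.sub", "typing", "collections.abc", ".local"]
print(sorted(modules, key=import_sort_key))
```

This prints `['collections.abc', 'typing', 'numpy.typing', '.local', 'mypkg.sub']`. Returning a tuple rather than a concatenated string keeps the group number and the name as separate sort criteria.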
> Instead of traversing sorted(), you could use sorted(.., key=)
This would be the best solution, but I want to add an extra blank line between imports from different parties, e.g.:
from typing import Sequence

from numpy.typing import ArrayLike
instead of:
from typing import Sequence
from numpy.typing import ArrayLike
If I just loop through sorted(self.imports, key=custom_key_func), it is not possible to tell where the stdlib imports end and the 3rd-party imports start, so I can't add the blank line.
However, I did simplify things a bit in the latest commit; please check.
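For what it's worth, a single sorted(..., key=...) pass does not by itself rule out the blank lines: equal group numbers end up adjacent after sorting, so `itertools.groupby` can split the sorted list back into per-party runs. A minimal sketch, not the code in the PR; `render_import_blocks` and `example_key` are hypothetical names, and the two-way key is a simplification of the three parties discussed here:

```python
from itertools import groupby
from typing import Callable, List, Tuple


def example_key(module: str) -> Tuple[int, str]:
    # Hypothetical two-way split: typing/collections first, everything else after
    root = module.split(".")[0]
    return (0 if root in ("typing", "collections") else 1, module)


def render_import_blocks(
    modules: List[str], key: Callable[[str], Tuple[int, str]]
) -> str:
    """Emit one 'import X' line per module, separating groups with a blank line."""
    ordered = sorted(modules, key=key)
    blocks = []
    # groupby only merges *adjacent* items with the same group number,
    # which is why the list is sorted with the same key first
    for _, group in groupby(ordered, key=lambda m: key(m)[0]):
        blocks.append("".join(f"import {m}\n" for m in group))
    return "\n".join(blocks)


print(render_import_blocks(["numpy", "typing", "collections"], example_key), end="")
```

This prints `import collections` and `import typing`, a blank line, then `import numpy`. Tracking the previous group in a loop variable, as the code below does with `last_party`, achieves the same thing without `groupby`.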
Alternative solution:
# Needed at module level
import importlib.util
from typing import List, Literal

def check_party(self, module: str) -> Literal[0, 1, 2]:
    """
    Check source of module
    0 = From stdlib
    1 = From 3rd party package
    2 = From the package being built
    """
    # Compare top-level names so submodules of the package (pkg.sub) also count as 2
    if module.startswith(".") or module.split(".")[0] == self.module.__name__.split(".")[0]:
        return 2
    try:
        spec = importlib.util.find_spec(module)
    except ModuleNotFoundError:
        return 1
    if spec:
        # Heuristic: anything installed under site-packages is treated as 3rd party
        if spec.origin and "site-packages" in spec.origin:
            return 1
        else:
            return 0
    else:
        return 1

def get(self) -> str:
    """Generate the final stub output"""
    s = ""
    last_party = None
    # Sort by party first (stdlib, 3rd party, own package), then by module name
    for module in sorted(self.imports, key=lambda i: (self.check_party(i), i)):
        imports = self.imports[module]
        items: List[str] = []
        party = self.check_party(module)
        # Emit a blank line whenever we cross into a new party
        if party != last_party:
            if last_party is not None:
                s += "\n"
            last_party = party
        for (k, v1), v2 in imports.items():
            if k is None:
                if v1 and v1 != module:
                    s += f"import {module} as {v1}\n"
                else:
                    s += f"import {module}\n"
            else:
                if k != v2 or v1:
                    items.append(f"{k} as {v2}")
                else:
                    items.append(k)
        items = sorted(items)
        if items:
            # Use a single line if it fits, otherwise a parenthesized block
            items_v0 = ", ".join(items)
            items_v0 = f"from {module} import {items_v0}\n"
            items_v1 = "(\n    " + ",\n    ".join(items) + "\n)"
            items_v1 = f"from {module} import {items_v1}\n"
            s += items_v0 if len(items_v0) <= 70 else items_v1
    s += "\n\n"
    s += self.put_abstract_enum_class()
    # Append the main generated stub
    s += self.output
    return s.rstrip() + "\n"
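The line-wrapping rule inside get() can be tested on its own. A minimal standalone restatement, assuming the same 70-character limit (measured including the trailing newline) and 4-space indentation; `render_from_import` is a hypothetical name:

```python
from typing import List


def render_from_import(module: str, names: List[str], max_len: int = 70) -> str:
    """Render 'from module import a, b' on one line, or a parenthesized block if too long."""
    names = sorted(names)
    one_line = f"from {module} import {', '.join(names)}\n"
    # Like get(), the length check includes the trailing newline
    if len(one_line) <= max_len:
        return one_line
    body = ",\n".join(f"    {n}" for n in names)
    return f"from {module} import (\n{body}\n)\n"


print(render_from_import("typing", ["Sequence", "Any"]), end="")
```

This prints `from typing import Any, Sequence`; a list of names long enough to exceed the limit comes out as one name per line inside parentheses.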
I am not sure if this is cleaner, though.
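If the "site-packages" string check in check_party ever proves fragile (for example, Debian's system Python installs third-party packages under dist-packages), Python 3.10+ exposes `sys.stdlib_module_names`, a frozenset of top-level standard-library module names. A minimal sketch of a variant built on it, falling back to the original heuristic on older Pythons. `classify_module`, `is_stdlib` and `package_root` are hypothetical names, not part of the PR. Because it only looks up the top-level name, it also avoids `find_spec("numpy.typing")` importing the parent `numpy` package:

```python
import importlib.util
import sys
from typing import Literal


def is_stdlib(root: str) -> bool:
    """True if the top-level module name belongs to the standard library."""
    names = getattr(sys, "stdlib_module_names", None)  # available on Python 3.10+
    if names is not None:
        return root in names
    # Fallback for older Pythons: the same site-packages heuristic as check_party
    spec = importlib.util.find_spec(root)
    return spec is not None and not (spec.origin and "site-packages" in spec.origin)


def classify_module(module: str, package_root: str) -> Literal[0, 1, 2]:
    """0 = stdlib, 1 = third party, 2 = the package being built."""
    root = module.split(".")[0]
    if module.startswith(".") or root == package_root:
        return 2
    return 0 if is_stdlib(root) else 1
```

With `package_root="mypkg"`, `typing` and `collections.abc` map to 0, `numpy.typing` to 1, and `mypkg.sub` or `.foo` to 2, matching the numbering in check_party.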
I like the version from your last comment better. Could you update the PR to reflect it?
PR updated, please check!
Great, thank you!
See https://github.com/wjakob/nanobind/issues/420#issuecomment-1986957935