python / mypy

Optional static typing for Python
https://www.mypy-lang.org/

Polynomial time-complexity of typing.overload #10004

Open pelson opened 3 years ago

pelson commented 3 years ago

Feature

The time complexity of dealing with function overloads in mypy appears to be polynomial (roughly quadratic in the measurements below). Naively, it seems this is not a fundamental limit, though I have no experience in static analysis to base this on. Below is a log-log graph of the number of overloads vs execution time:

[Figure_1: log-log plot of number of overloads vs time to run mypy, in seconds]

The data for the graph was generated with the following code:

```
import json
import pathlib
import tempfile
import textwrap
import time

from mypy import api


def generate_n_overloads(n: int, types=('str', 'int', 'float')) -> str:
    result = ['import typing']
    for name_count in range(n):
        type_name = types[name_count % len(types)]
        name = f'overload_{name_count}_{type_name}'
        annotation = (
            textwrap.dedent(f'''
                @typing.overload
                def get_param(param: typing.Literal['{name}']) -> {type_name}: ...
            ''')
        )
        result.append(annotation)
    return '\n'.join(result)


def build_test_case(dest_dir: pathlib.Path, n_overloads: int) -> pathlib.Path:
    pkg_dir = dest_dir / 'example_overloads_pkg'
    pkg_dir.mkdir(exist_ok=True)
    stubfile = pkg_dir / '__init__.pyi'
    stubfile.write_text(generate_n_overloads(n_overloads))
    pkg_file = pkg_dir / '__init__.py'
    pkg_file.write_text('def get_param(param): ...')
    test_script = dest_dir / 'script.py'
    test_script.write_text(textwrap.dedent('''
        from example_overloads_pkg import get_param
        get_param('overload_3_str').upper()  # Good
        get_param('overload_2_float').upper()  # Bad
    '''))
    return test_script


if __name__ == '__main__':
    tmpdir = tempfile.TemporaryDirectory()

    scale_times = {}
    numbers_to_check = [16, 32, 64, 128, 256, 512, 1024]
    # numbers_to_check = [16, 32, 64, 128, 256, 300, 400, 500, 600, 700, 800, 900, 1000, 1500, 2000, 2500, 3000, 4000, 5000, 6000, 8000, 10000]
    for number_of_overloads in numbers_to_check:
        print(f'Overload size: {number_of_overloads}')
        test_script = build_test_case(pathlib.Path(tmpdir.name), number_of_overloads)

        start = time.perf_counter()
        result = api.run([str(test_script)])
        end = time.perf_counter()
        elapsed = end - start
        print(' Time:', end - start)
        if scale_times:
            last_n = next(reversed(scale_times))
            print(f" Slowdown: x{elapsed / scale_times[last_n]:.2f}")
        scale_times[number_of_overloads] = elapsed

        # Save the times to json for plotting/analysis.
        with open('times.json', 'wt') as fh:
            json.dump(scale_times, fh)
```

And the plot:

```
import json

import matplotlib.pyplot as plt
import numpy as np

with open('times.json', 'rt') as fh:
    times = json.load(fh)

n_points, times = zip(*times.items())
n_points = np.array(n_points, dtype=int)
times = np.array(times)

fit = np.polyfit(n_points, times, 2)
p = np.poly1d(fit)

plt.title('Number of overloads vs execution time of mypy')
plt.xscale('log')
plt.yscale('log')
plt.scatter(n_points, times)
plt.xlabel('Number of overloads')
plt.ylabel('Time to run mypy / seconds')
plt.show()
```

The full dataset (``times.json``):

```
{"16": 0.11653269396629184, "32": 0.13803347101202235, "64": 0.1889677940052934,
 "128": 0.42529276601271704, "256": 1.4768535409821197, "300": 2.165687720000278,
 "400": 3.581374307977967, "500": 5.615040614036843, "600": 8.182245634961873,
 "700": 10.219899011019152, "800": 14.204097100009676, "900": 19.01026200200431,
 "1000": 21.914704270020593, "1500": 50.26287814299576, "2000": 88.03798637498403,
 "2500": 140.229756515997, "3000": 206.8722462329897, "4000": 373.4201490940177,
 "5000": 596.6820141010103, "6000": 831.5039637690061, "8000": 1579.1185438489774,
 "10000": 2624.7711193099967}
```
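For anyone wanting to put a number on the scaling, here is a minimal sketch that estimates the exponent from the dataset above with a least-squares fit in log-log space. It uses only the points with 1000 or more overloads, on the assumption that fixed start-up cost is negligible there; the helper name `loglog_slope` is made up for this illustration. On this data the slope should come out at roughly 2, i.e. quadratic growth.

```python
import math

# Timings copied from times.json (number of overloads -> seconds), n >= 1000 only.
# Restricting to large n is an assumption made to keep start-up cost out of the fit.
times = {
    1000: 21.914704270020593,
    1500: 50.26287814299576,
    2000: 88.03798637498403,
    2500: 140.229756515997,
    3000: 206.8722462329897,
    4000: 373.4201490940177,
    5000: 596.6820141010103,
    6000: 831.5039637690061,
    8000: 1579.1185438489774,
    10000: 2624.7711193099967,
}


def loglog_slope(points: dict[int, float]) -> float:
    """Least-squares slope of log(time) against log(n)."""
    xs = [math.log(n) for n in points]
    ys = [math.log(t) for t in points.values()]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    numerator = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    denominator = sum((x - x_mean) ** 2 for x in xs)
    return numerator / denominator


print(f"estimated exponent: {loglog_slope(times):.2f}")
```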

I have inherited a library which has a very simple DSL to query data from a database and which returns a dictionary of different types depending on the input. Indeed, the language is so simple that I could map these to a collection of literal overloads, with a generic fallback for new/unmapped data. In practice this looks something like:

import typing

@typing.overload
def get_param(param: typing.Literal['some_float']) -> float: ...

@typing.overload
def get_param(param: typing.Literal['some_other_int']) -> int: ...

@typing.overload
def get_param(param: typing.Literal['further_str']) -> str: ...

Being able to do static analysis (and completion) on such an interface would be hugely beneficial to its users for correctness and ease of use. Unfortunately, the total number of overloads would need to be on the order of 100,000. Clearly, from the graph, this number is prohibitive with the current implementation of mypy's overload functionality.
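To make "prohibitive" concrete, here is a rough back-of-the-envelope sketch. It assumes the quadratic trend simply continues beyond the measured range, which is an extrapolation and not something that was measured, and it scales up the 10,000-overload timing from the dataset above. The function name is made up for this illustration. Under that assumption, 100,000 overloads would take on the order of three days.

```python
def extrapolate_quadratic(known_n: int, known_seconds: float, target_n: int) -> float:
    """Scale a measured time to a larger n, assuming time grows with n squared."""
    return known_seconds * (target_n / known_n) ** 2


# Measured: 10,000 overloads took about 2624.77 seconds (from times.json).
estimate_seconds = extrapolate_quadratic(10_000, 2624.7711193099967, 100_000)
print(f"estimated time for 100,000 overloads: {estimate_seconds / 3600:.0f} hours")
```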

Pitch

I'm therefore reaching out to find out if there is a known good reason for this time-complexity, and potentially seeking help/pointers to address the issue in the implementation.

Other places this could be useful

Truth be told, this is a highly specialised requirement which would apply to a fairly limited set of use cases (some of which are likely anti-patterns in the first place).

Given the high number of overloads involved for this time-complexity to become an issue (at around 100 overloads), it is clear that you'd almost always want to generate the stubs rather than hand-craft them. Examples of where it could come in handy are other string-based lookups. I can imagine, for example, wanting to generate stubs for all of the JClasses that exist in JPype (which can be any Java classes available in the JVM), or perhaps for some interface that allows you to get hold of a CSS property by name and which returns a slightly different type for each of the ~520 valid property names. I could also imagine this being used for generating the permutations of a multiple-dispatch system.

henribru commented 3 years ago

This might be of some relevance: https://github.com/microsoft/pyright/commit/d9f621e19cef94675d6fe23e4208c98f9d863d6b

It seems Pyright only avoids quadratic complexity by skipping some form of consistency check when the number of overloads is large enough.

I'm also interested in this feature, as I maintain one stubs package which currently uses about 300 overloads and another which would benefit from overloads but would require around 10 000 of them.

henribru commented 3 years ago

Looks like a consistency check is probably the cause for Mypy as well: https://github.com/python/mypy/blob/master/mypy/checker.py#L486
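To illustrate why a check like this would scale quadratically, here is a toy sketch. It is not mypy's actual code, and `count_pairwise_checks` is a made-up name. It only assumes that the consistency check compares every overload against every other one, which gives n(n-1)/2 comparisons for n overloads.

```python
from itertools import combinations


def count_pairwise_checks(n_overloads: int) -> int:
    """Count comparisons when each overload is checked against every other one."""
    # Stand-ins for overload signatures; a real checker would compare types here.
    signatures = [f"overload_{i}" for i in range(n_overloads)]
    return sum(1 for _ in combinations(signatures, 2))


for n in (16, 128, 1024):
    print(n, count_pairwise_checks(n))
```

Doubling the number of overloads roughly quadruples the number of pairs, which is in line with the slowdowns in the timings above.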

hauntsaninja commented 3 years ago

https://github.com/python/mypy/pull/10922 can help here. This skips the quadratic check in files which mypy wouldn't have reported errors in (i.e. third party code), for instance, installed stub packages.