Closed saulshanabrook closed 4 months ago
Nit: I noticed that the new test is 1.25 MB, which is larger than the current src and tests directories combined. Any chance we can make it smaller?
Yeah it is a bit big. I was trying to find a smaller reproducible version, but it was hard to get the worst-case performance to reproduce manually.
Closes https://github.com/egraphs-good/egglog/issues/388 by using an immutable data type for type checking.
I also added a test case that reproduces the slowdown.
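For readers unfamiliar with the idea, here is a minimal sketch of why an immutable (persistent) structure can make this kind of code cheaper. It is not the code from this PR: `Env`, `bind`, and `lookup` are hypothetical names, and it assumes a setting where each nested scope would otherwise copy a mutable map of bindings.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Env:
    """Hypothetical immutable type environment: a linked chain of bindings."""

    name: Optional[str] = None
    tp: Optional[str] = None
    parent: Optional["Env"] = None

    def bind(self, name: str, tp: str) -> "Env":
        # O(1): the new environment shares the entire old one as its parent,
        # so nothing is copied and the old environment stays valid.
        return Env(name, tp, self)

    def lookup(self, name: str) -> Optional[str]:
        # Walk from the innermost binding outwards.
        env: Optional[Env] = self
        while env is not None:
            if env.name == name:
                return env.tp
            env = env.parent
        return None


base = Env().bind("x", "i64")
inner = base.bind("y", "String")  # does not copy or mutate `base`
```

Extending `base` leaves it untouched, so many scopes can share one environment without a defensive copy. The trade-off in this toy version is a linear-time lookup; real persistent maps use trees to keep lookups fast.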
Running `time ./target/release/egglog tests/python_array_optimize.egg` before the immutable data types:

And after:
Are there regressions?
When I previously tested this against Python, I also observed that it made everything else a bit slower. We don't have performance testing set up in this repo, so it's a bit cumbersome to measure, but as a sanity check I ran the eggcc example. After this change:
Before this change:
I am not sure whether the change here is significant; it would be hard to tell without repeated testing and some statistics. Overall, though, the baseline performance seems comparable.
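As a rough idea of what "repeated testing and some statistics" could look like, here is a small sketch that times a command several times and reports the mean and sample standard deviation. The helper names `time_command` and `summarize` are made up for illustration, and the command in the `__main__` block is a placeholder rather than the real egglog or eggcc invocation.

```python
import statistics
import subprocess
import sys
import time


def time_command(cmd: list[str], runs: int = 10) -> list[float]:
    """Run `cmd` `runs` times and return the wall-clock durations in seconds."""
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        # check=True raises if the command fails, so a broken run is not timed silently.
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations: list[float]) -> tuple[float, float]:
    """Return (mean, sample standard deviation); needs at least two durations."""
    return statistics.mean(durations), statistics.stdev(durations)


if __name__ == "__main__":
    # Placeholder command; replace with the binary and input file to compare.
    cmd = [sys.executable, "-c", "pass"]
    mean, stdev = summarize(time_command(cmd, runs=5))
    print(f"mean={mean:.4f}s stdev={stdev:.4f}s")
```

Running this once per build (before and after the change) and comparing the means against the standard deviations would give a first indication of whether a difference is larger than the run-to-run noise.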
Generating Example
To generate the example I ran this Python file to reproduce one of the test cases:

```python
import time
from pathlib import Path

from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.exp.array_api import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *

iris = datasets.load_iris()
X_np, y_np = (iris.data, iris.target)


def run_lda(x, y):
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        return lda.fit(x, y).transform(x)


def trace_lda(egraph: EGraph):
    X_arr = NDArray.var("X")
    assume_dtype(X_arr, X_np.dtype)
    assume_shape(X_arr, X_np.shape)
    assume_isfinite(X_arr)

    y_arr = NDArray.var("y")
    assume_dtype(y_arr, y_np.dtype)
    assume_shape(y_arr, y_np.shape)
    assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np))))  # type: ignore[arg-type]

    with egraph:
        return run_lda(X_arr, y_arr)


egraph = EGraph()
expr = trace_lda(egraph)
egraph = EGraph(save_egglog_string=True)
start_time = time.time()
egraph.simplify(expr, array_api_numba_schedule)
print(f"Elapsed time: {time.time() - start_time}")
(Path(__file__).parent / "tmp.egg").write_text(egraph.as_egglog_string)
```

I also had to patch the function naming code so that it could be properly parsed by egglog:

```diff
diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py
index ec82c5d..bc59d21 100644
--- a/python/egglog/egraph.py
+++ b/python/egglog/egraph.py
@@ -672,6 +672,9 @@ class _FunctionConstructor:
         return decls, ref
 
 
+ALLOWED = "-+*/?!=<>&|^/%_"
+
+
 def _fn_body_name(fn: Callable, decls: Declarations, function_decl: FunctionDecl) -> str:
     """
     Creates a function name from the function body.
@@ -685,7 +688,9 @@ def _fn_body_name(fn: Callable, decls: Declarations, function_decl: FunctionDecl
     ]
     fn = f"lambda {', '.join(signature.arg_names)}: {fn(*args)}"
     tp = f"Callable[[{', '.join(str(tp) for tp in signature.arg_types)}], {signature.return_type}]"
-    return f"cast({tp}, {fn})"
+    res = f"cast({tp}, {fn})"
+    # replace all non alphanumeric characters with underscores
+    return "".join(c if c.isalnum() or c in ALLOWED else "_" for c in res)
 
 
 def d(x):
```
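To see what the patched naming does in isolation, here is a small standalone sketch. `ALLOWED` and the join expression are copied from the diff above; the wrapper name `sanitize_name` is hypothetical and does not exist in egglog.

```python
# Copied verbatim from the patch: symbol characters that are kept as-is.
ALLOWED = "-+*/?!=<>&|^/%_"


def sanitize_name(res: str) -> str:
    """Replace every character that is neither alphanumeric nor in ALLOWED with '_'."""
    return "".join(c if c.isalnum() or c in ALLOWED else "_" for c in res)


# Parentheses, commas and spaces become underscores; letters and '+' are kept.
print(sanitize_name("f(x, y)"))
print(sanitize_name("a+b"))
```

Note that the mapping is lossy: different function bodies can collapse to the same name (for example, `f(x)` and `f[x]` both become `f_x_`), which is presumably acceptable for generating a one-off test file. Also, `str.isalnum()` accepts non-ASCII letters and digits, so those pass through unchanged.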