bashtage / arch

ARCH models in Python

Sped up covariance estimate via numba #686

Closed wolph closed 12 months ago

wolph commented 12 months ago

Benchmarked using: py.test arch/tests/unitroot/test_fmols_ccr.py

Before:

============================= fixture duration top ============================== 
total          name              num   avg            min
0:00:00.080000             trend  1728        0:00:00 0:00:00
0:00:00.062000              diff  1728        0:00:00 0:00:00
0:00:00.032000            kernel  1728        0:00:00 0:00:00
0:00:00.031000         force_int  1728        0:00:00 0:00:00
0:00:00.016000   trivariate_data     2 0:00:00.008000 0:00:00
0:00:00.016000           x_trend  1728        0:00:00 0:00:00
0:00:00.016000         bandwidth  1728        0:00:00 0:00:00
0:00:00.253000       grand total 10426        0:00:00 0:00:00
============================ test call duration top ============================= 
total          name              num   avg            min
0:00:06.062000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:06.001000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.093000   test_ccr_eviews    26        0:00:00 0:00:00
0:00:00.078000 test_fmols_eviews    26        0:00:00 0:00:00
0:00:12.234000       grand total  1784        0:00:00 0:00:00
============================ test setup duration top ============================ 
total          name              num   avg            min
0:00:00.377000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:00.295000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.672000       grand total  1784        0:00:00 0:00:00
========================== test teardown duration top =========================== 
total          name              num   avg            min
0:00:00.142000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:00.124000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.266000       grand total  1784        0:00:00 0:00:00
============================= 1784 passed in 14.37s ============================= 

After:

============================= fixture duration top ============================== 
total          name              num   avg            min
0:00:00.063000             trend  1728        0:00:00 0:00:00
0:00:00.048000           x_trend  1728        0:00:00 0:00:00
0:00:00.032000         force_int  1728        0:00:00 0:00:00
0:00:00.016000   trivariate_data     2 0:00:00.008000 0:00:00
0:00:00.016000            kernel  1728        0:00:00 0:00:00
0:00:00.015000         bandwidth  1728        0:00:00 0:00:00
0:00:00.190000       grand total 10426        0:00:00 0:00:00
============================ test call duration top ============================= 
total          name              num   avg            min
0:00:04.528000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:03.748000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.094000 test_fmols_eviews    26        0:00:00 0:00:00
0:00:00.094000   test_ccr_eviews    26        0:00:00 0:00:00
0:00:08.464000       grand total  1784        0:00:00 0:00:00
============================ test setup duration top ============================ 
total          name              num   avg            min
0:00:00.392000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.329000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:00.721000       grand total  1784        0:00:00 0:00:00
========================== test teardown duration top =========================== 
total          name              num   avg            min
0:00:00.109000  test_fmols_smoke   864        0:00:00 0:00:00
0:00:00.064000    test_ccr_smoke   864        0:00:00 0:00:00
0:00:00.016000   test_ccr_eviews    26        0:00:00 0:00:00
0:00:00.189000       grand total  1784        0:00:00 0:00:00
============================= 1784 passed in 10.34s ============================= 

Nothing too dramatic, but a useful performance improvement: the total run time drops from 14.37s to 10.34s (about 28%) :)

I'm still looking at other bottlenecks right now, so depending on your preference, I can create separate pull requests or a single large one :)
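
For readers who have not seen the change itself (it is not shown in this thread), here is a minimal sketch of the kind of loop that numba typically speeds up: a kernel-weighted sum of autocovariances. The function name `weighted_cov`, the Bartlett weights and the no-op fallback decorator are assumptions made for illustration, not the code in `arch/covariance/kernel.py`.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Hypothetical no-op fallback so the sketch also runs without numba.
    def jit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@jit(nopython=True)
def weighted_cov(x, weights):
    """Sum of lag autocovariances of x, weighted by weights[lag]."""
    nobs, k = x.shape
    cov = np.zeros((k, k))
    for lag in range(weights.shape[0]):
        # gamma[i, j] = sum_t x[t, i] * x[t - lag, j] / nobs
        gamma = np.zeros((k, k))
        for t in range(lag, nobs):
            for i in range(k):
                for j in range(k):
                    gamma[i, j] += x[t, i] * x[t - lag, j] / nobs
        w = weights[lag]
        if lag == 0:
            cov += w * gamma
        else:
            # Lags enter symmetrically: gamma and its transpose.
            cov += w * (gamma + gamma.T)
    return cov


# Usage: Bartlett weights with bandwidth 4 on simulated data.
rng = np.random.default_rng(0)
data = rng.standard_normal((200, 3))
bartlett = 1.0 - np.arange(5) / 5.0
long_run_cov = weighted_cov(data, bartlett)
```

Explicit loops like these are slow in pure Python but compile well under `nopython=True`, which is why a fallback decorator is needed when numba is unavailable.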

codecov[bot] commented 12 months ago

Codecov Report

All modified lines are covered by tests :white_check_mark:

Files Coverage Δ
arch/covariance/kernel.py 100.00% <100.00%> (ø)


bashtage commented 12 months ago

Looks good. Can you run isort to fix the imports?

wolph commented 12 months ago

I actually ran isort, black and flake8 and none of those show issues with my code.

I'm guessing there's been a black update. This is the diff that black is producing now:

$ git diff
diff --git a/arch/univariate/base.py b/arch/univariate/base.py
index 4847000f8..f59fd7c1c 100644
--- a/arch/univariate/base.py
+++ b/arch/univariate/base.py
@@ -1189,10 +1189,7 @@ class ARCHModelFixedResult(_SummaryRepr):

         stubs = list(self._names)
         header = ["coef"]
-        param_table_data = [
-            [format_float_fixed(param, 10, 4)]
-            for param in self.params
-        ]
+        param_table_data = [[format_float_fixed(param, 10, 4)] for param in self.params]

         mc = self.model.num_params
         vc = self.model.volatility.num_params

I can apply it if you want to, but it's not really part of this PR :)

The other issue is (was, I just pushed an update) that the signature for the jit function is incorrect: it doesn't allow usage as a decorator with arguments.

wolph commented 12 months ago

Disregard part of that last comment... I was looking at the wrong branch ;)

The diff above still applies, though.

bashtage commented 12 months ago

To get the decorator to work, I think you can use:

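import functools
from typing import Any, Callable

# DISABLE_NUMBA, performance_warning and PerformanceWarning are assumed
# to be defined elsewhere in the module.
try:
    if DISABLE_NUMBA:
        raise ImportError

    from numba import jit

    jit = functools.partial(jit, nopython=True)

except ImportError:

    def jit(
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
            @functools.wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                import warnings

                warnings.warn(performance_warning, PerformanceWarning)
                return func(*args, **kwargs)

            return wrapper

        # Bare use (@jit): the decorated function is the first argument.
        if args and callable(args[0]):
            return wrap(args[0])

        # Use with arguments (@jit(...)): return the actual decorator.
        return wrap

bashtage commented 12 months ago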

I pushed an improved version of the fix.
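
The pushed version is not shown in this thread. As a rough, self-contained illustration of the requirement discussed above (a fallback `jit` that works both bare and with arguments), something like the following would do. The `PerformanceWarning` class and the warning text here are stand-ins defined only for this sketch, not the library's own.

```python
import functools
import warnings
from typing import Any, Callable


class PerformanceWarning(UserWarning):
    """Stand-in warning class, defined here only for the sketch."""


def jit(*args: Any, **kwargs: Any) -> Any:
    """Fallback decorator usable as @jit or as @jit(...)."""

    def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*a: Any, **kw: Any) -> Any:
            warnings.warn(
                "numba is not available; running pure Python",
                PerformanceWarning,
            )
            return func(*a, **kw)

        return wrapper

    # @jit: Python passes the function directly as the only argument.
    if len(args) == 1 and callable(args[0]) and not kwargs:
        return wrap(args[0])
    # @jit(nopython=True): options are ignored; return the decorator.
    return wrap


@jit
def add(a: int, b: int) -> int:
    return a + b


@jit(nopython=True)
def mul(a: int, b: int) -> int:
    return a * b
```

Both `add` and `mul` remain callable and emit the warning on each call, so callers need no changes whether numba is installed or not.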

bashtage commented 12 months ago

Thanks. Merged these in after some surgery on the compat layer.

wolph commented 11 months ago

Thanks for the quick merge and the awesome project!