jacopok / mlgw_bns

Accelerating gravitational wave template generation with machine learning.
GNU General Public License v3.0

Speed up `Model` prediction #7

Open jacopok opened 2 years ago

jacopok commented 2 years ago

Specifically, the ideas are

jacopok commented 2 years ago

Simply switching to linear interpolation makes the downsampled arrays way too large; relaxing the tolerance then fixes this. When running the test suite with `degree = 1` and `tolerance = 1e-5` for the `DownsamplingTraining` I get a bunch of errors: seemingly something breaks and the PCA feeds NaNs to the NN.

Still, some benchmarks do come through and they look promising: around a 30 to 40% improvement in the evaluation of a waveform!

It's probably worth investigating this a little bit more.
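
For context, the greedy downsampling I'm talking about works roughly like this; a minimal sketch with scipy, not the actual `DownsamplingTraining` implementation (in particular, it uses an absolute tolerance for simplicity):

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

def greedy_downsample(x, y, degree=3, tol=1e-6):
    """Greedily select indices so that a spline of the given degree
    through the selected points reproduces y within tol everywhere.

    Illustrative sketch only, not the repo's DownsamplingTraining code.
    """
    # start from a few evenly spaced points (a degree-k spline needs at least k+1)
    indices = set(np.linspace(0, len(x) - 1, degree + 2, dtype=int))
    while True:
        idx = np.array(sorted(indices))
        spline = InterpolatedUnivariateSpline(x[idx], y[idx], k=degree)
        err = np.abs(spline(x) - y)
        worst = int(np.argmax(err))
        if err[worst] < tol:
            return idx
        indices.add(worst)  # add the worst-reproduced point and repeat
```

With `degree=1` each spline evaluation is cheaper, but hitting the same tolerance requires selecting many more points, which is exactly the size blow-up described above.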

jacopok commented 2 years ago

Solving https://github.com/jacopok/mlgw_bns/issues/14 addressed that problem; however, it's still unclear whether this change brings a concrete benefit.

Little data dump: with degree 1 polynomials and tolerance 1e-5 I get

------------------------------------------------------------------------------------------ benchmark 'model-prediction': 16 tests -----------------------------------------------------------------------------------------
Name (time in ms)                                                Min                Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_model_nn_prediction[default_model-1e-05-0.002-128]        1.2580 (1.0)       3.7685 (1.0)       1.4330 (1.0)      0.2074 (1.32)      1.4085 (1.0)      0.0249 (1.33)         5;30  697.8146 (1.0)         300           1
test_model_nn_prediction[default_model-1e-05-0.002-256]        1.2802 (1.02)     14.1466 (3.75)      1.6364 (1.14)     0.8819 (5.59)      1.4589 (1.04)     0.0483 (2.58)        37;64  611.0807 (0.88)        569           1
test_model_nn_prediction[default_model-1e-05-0.002-512]        1.2995 (1.03)     19.7520 (5.24)      1.8235 (1.27)     1.0200 (6.47)      1.7167 (1.22)     0.0769 (4.10)       22;116  548.3904 (0.79)        527           1
test_model_nn_prediction[default_model-1e-05-0.002-1024]       1.5321 (1.22)     16.5949 (4.40)      2.2405 (1.56)     1.7813 (11.30)     1.8142 (1.29)     0.0805 (4.30)        34;96  446.3300 (0.64)        495           1
test_model_nn_prediction[default_model-1e-05-0.002-2048]       1.8507 (1.47)      6.7831 (1.80)      2.1474 (1.50)     0.7038 (4.46)      1.9888 (1.41)     0.0187 (1.0)        26;101  465.6855 (0.67)        422           1
test_model_nn_prediction[default_model-1e-05-0.002-4096]       1.9665 (1.56)      7.1243 (1.89)      2.4454 (1.71)     0.7665 (4.86)      2.3062 (1.64)     0.0761 (4.06)        21;58  408.9358 (0.59)        390           1
test_model_nn_prediction[default_model-1e-05-0.002-8192]       2.4725 (1.97)     13.0120 (3.45)      3.1760 (2.22)     1.4534 (9.22)      2.9451 (2.09)     0.1048 (5.59)        14;67  314.8658 (0.45)        327           1
test_model_nn_prediction[default_model-1e-05-0.002-16384]      3.5232 (2.80)     23.2351 (6.17)      4.5940 (3.21)     1.9053 (12.08)     4.2386 (3.01)     0.2093 (11.17)       18;47  217.6746 (0.31)        238           1
test_model_nn_prediction[trained_model-0.015-0.1-1024]        10.7251 (8.53)     26.5114 (7.04)     14.4041 (10.05)    1.6820 (10.67)    14.1557 (10.05)    0.3602 (19.22)         7;9   69.4247 (0.10)         74           1
test_model_nn_prediction[trained_model-0.015-0.1-128]         12.7998 (10.17)    13.2045 (3.50)     12.9819 (9.06)     0.1577 (1.0)      12.9613 (9.20)     0.2358 (12.58)         2;0   77.0305 (0.11)          5           1
test_model_nn_prediction[trained_model-0.015-0.1-256]         12.8385 (10.21)    30.4628 (8.08)     14.2866 (9.97)     2.3101 (14.65)    13.7358 (9.75)     0.4837 (25.81)        2;10   69.9957 (0.10)         75           1
test_model_nn_prediction[trained_model-0.015-0.1-512]         13.5945 (10.81)    29.9235 (7.94)     14.3677 (10.03)    1.9947 (12.65)    13.8546 (9.84)     0.4409 (23.53)        3;10   69.6006 (0.10)         73           1
test_model_nn_prediction[trained_model-0.015-0.1-2048]        14.0007 (11.13)    28.0412 (7.44)     15.6125 (10.89)    3.3852 (21.47)    14.4004 (10.22)    0.2377 (12.68)        8;11   64.0512 (0.09)         73           1
test_model_nn_prediction[trained_model-0.015-0.1-4096]        14.3069 (11.37)    32.5939 (8.65)     15.6227 (10.90)    2.9918 (18.98)    14.7948 (10.50)    0.2258 (12.05)        6;10   64.0095 (0.09)         71           1
test_model_nn_prediction[trained_model-0.015-0.1-8192]        15.1870 (12.07)    29.6823 (7.88)     16.4394 (11.47)    3.0402 (19.28)    15.6051 (11.08)    0.2849 (15.20)         5;6   60.8295 (0.09)         68           1
test_model_nn_prediction[trained_model-0.015-0.1-16384]       16.5819 (13.18)    39.2085 (10.40)    18.2885 (12.76)    4.7163 (29.91)    16.9444 (12.03)    0.2643 (14.10)         4;5   54.6790 (0.08)         51           1

and the reconstruction from the default model is not accurate enough, since that model was trained with cubics.

On the other hand, with the regular degree 3 and tolerance 1e-6 I get

------------------------------------------------------------------------------------------ benchmark 'model-prediction': 16 tests -----------------------------------------------------------------------------------------
Name (time in ms)                                                Min                Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_model_nn_prediction[trained_model-0.015-0.1-256]         1.3293 (1.0)       6.0348 (3.31)     1.4915 (1.0)      0.2724 (1.46)     1.4630 (1.02)     0.0442 (1.67)        16;45  670.4851 (1.0)         535           1
test_model_nn_prediction[trained_model-0.015-0.1-128]         1.3690 (1.03)      1.8233 (1.0)      1.5113 (1.01)     0.1860 (1.0)      1.4348 (1.0)      0.2228 (8.41)          1;0  661.6996 (0.99)          5           1
test_model_nn_prediction[trained_model-0.015-0.1-512]         1.4139 (1.06)      9.4266 (5.17)     1.6742 (1.12)     0.6651 (3.58)     1.5254 (1.06)     0.0265 (1.0)        23;146  597.3100 (0.89)        518           1
test_model_nn_prediction[trained_model-0.015-0.1-1024]        1.5398 (1.16)     10.1602 (5.57)     1.8579 (1.25)     0.6880 (3.70)     1.6753 (1.17)     0.0680 (2.57)        35;89  538.2405 (0.80)        506           1
test_model_nn_prediction[default_model-1e-05-0.002-128]       1.6048 (1.21)     23.4007 (12.83)    2.1189 (1.42)     1.5349 (8.25)     1.8566 (1.29)     0.0643 (2.43)         9;65  471.9442 (0.70)        297           1
test_model_nn_prediction[default_model-1e-05-0.002-256]       1.6584 (1.25)     13.8404 (7.59)     1.9755 (1.32)     0.5949 (3.20)     1.9189 (1.34)     0.0659 (2.49)        15;49  506.2134 (0.75)        477           1
test_model_nn_prediction[default_model-1e-05-0.002-512]       1.6843 (1.27)      5.7041 (3.13)     2.0158 (1.35)     0.2356 (1.27)     1.9935 (1.39)     0.0610 (2.30)        24;51  496.0744 (0.74)        446           1
test_model_nn_prediction[trained_model-0.015-0.1-2048]        1.7184 (1.29)     15.5453 (8.53)     2.1199 (1.42)     0.9565 (5.14)     1.9298 (1.35)     0.0647 (2.44)        20;84  471.7265 (0.70)        459           1
test_model_nn_prediction[default_model-1e-05-0.002-2048]      1.7818 (1.34)     14.3200 (7.85)     2.6234 (1.76)     0.8604 (4.62)     2.4196 (1.69)     0.1698 (6.41)        25;73  381.1890 (0.57)        341           1
test_model_nn_prediction[default_model-1e-05-0.002-1024]      1.8117 (1.36)      4.8431 (2.66)     2.1859 (1.47)     0.2734 (1.47)     2.1310 (1.49)     0.0300 (1.13)       34;109  457.4687 (0.68)        429           1
test_model_nn_prediction[trained_model-0.015-0.1-4096]        2.1469 (1.62)     13.0112 (7.14)     2.7577 (1.85)     1.2618 (6.78)     2.4150 (1.68)     0.0643 (2.43)        25;82  362.6270 (0.54)        382           1
test_model_nn_prediction[default_model-1e-05-0.002-4096]      2.5948 (1.95)     24.1553 (13.25)    3.9541 (2.65)     2.1245 (11.42)    3.3725 (2.35)     0.3505 (13.23)       20;66  252.9022 (0.38)        325           1
test_model_nn_prediction[trained_model-0.015-0.1-8192]        2.9343 (2.21)      5.9702 (3.27)     3.4748 (2.33)     0.3938 (2.12)     3.3589 (2.34)     0.0916 (3.46)        32;53  287.7829 (0.43)        288           1
test_model_nn_prediction[default_model-1e-05-0.002-8192]      4.1177 (3.10)     15.1022 (8.28)     4.5914 (3.08)     1.0142 (5.45)     4.4149 (3.08)     0.0448 (1.69)         7;56  217.7967 (0.32)        210           1
test_model_nn_prediction[trained_model-0.015-0.1-16384]       4.6797 (3.52)     17.4683 (9.58)     5.5501 (3.72)     1.2305 (6.61)     5.2524 (3.66)     0.2354 (8.89)        10;31  180.1772 (0.27)        196           1
test_model_nn_prediction[default_model-1e-05-0.002-16384]     5.4990 (4.14)     12.6775 (6.95)     6.8500 (4.59)     0.7650 (4.11)     6.7874 (4.73)     0.3718 (14.04)       34;34  145.9861 (0.22)        166           1
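
For reference, the numbers above come from pytest-benchmark; a test of roughly this shape produces them (hedged: the `model` fixture and the `predict(frequencies)` call signature are placeholders for illustration, not the repo's actual fixtures or API):

```python
import numpy as np
import pytest

@pytest.mark.benchmark(group="model-prediction")
@pytest.mark.parametrize("n_points", [128, 256, 512, 1024, 2048, 4096, 8192, 16384])
def test_model_nn_prediction(benchmark, model, n_points):
    # `model` is a placeholder fixture providing a (trained or default) Model;
    # the frequency range is arbitrary and only fixes the size of the output grid.
    frequencies = np.geomspace(20.0, 2048.0, n_points)
    benchmark(model.predict, frequencies)
```
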
jacopok commented 2 years ago

So, it seems quite clear that there are two forces pulling in opposite directions: on one hand, linear splines are faster to evaluate, by about a factor of 2; on the other, reaching the required accuracy with linear splines means keeping a lot more points.

How many more, exactly, and how far does the tolerance need to be tightened to keep the mismatches acceptable? Hard to say: without a proper investigation I'd be hard-pressed to make a judgement. So far it seems safe to say that there are trade-offs, and the cubics seem "safer" until that investigation is done.
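
To get a feel for the factor-of-2 claim in isolation, one can time just the spline evaluation step on a dense output grid. This is an illustrative snippet, not a benchmark from the repo, and it deliberately uses the same number of nodes for both splines, which ignores the "linear needs more points" side of the trade-off:

```python
import timeit
import numpy as np
from scipy.interpolate import CubicSpline

x_nodes = np.linspace(0.0, 1.0, 500)      # stand-in for the downsampled grid
y_nodes = np.sin(50 * x_nodes)            # stand-in for amplitude/phase data
x_dense = np.linspace(0.0, 1.0, 16384)    # grid the waveform is evaluated on

cubic = CubicSpline(x_nodes, y_nodes)

# timeit with number=1000: total seconds * 1e3 = microseconds per call
t_lin = timeit.timeit(lambda: np.interp(x_dense, x_nodes, y_nodes), number=1000)
t_cub = timeit.timeit(lambda: cubic(x_dense), number=1000)
print(f"linear: {t_lin * 1e3:.1f} µs/call, cubic: {t_cub * 1e3:.1f} µs/call")
```
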

jacopok commented 2 years ago

If we're using cubic splines, increasing the tolerance in the greedy downsampling to 1e-5 seems not to impact the reconstruction at all, so that's great! Maybe we can increase it even a bit more.
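
A quick way to sanity-check this kind of statement is to compare the waveform rebuilt from the downsampled grid against the full one with a simple mismatch. A sketch, with several assumptions: a flat noise curve rather than a detector PSD, amplitude and phase downsampled on the same index set (the real code may use separate grids), and `indices` coming from something like the greedy sketch earlier in this thread:

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

def flat_noise_mismatch(h1, h2):
    """1 - |<h1, h2>| / sqrt(<h1, h1><h2, h2>), assuming white noise."""
    overlap = np.abs(np.vdot(h1, h2))
    norm = np.sqrt(np.vdot(h1, h1).real * np.vdot(h2, h2).real)
    return 1.0 - overlap / norm

def reconstruction_mismatch(freqs, amplitude, phase, indices, degree=3):
    """Mismatch between the full waveform and the one rebuilt from `indices`."""
    amp_rec = InterpolatedUnivariateSpline(freqs[indices], amplitude[indices], k=degree)(freqs)
    phi_rec = InterpolatedUnivariateSpline(freqs[indices], phase[indices], k=degree)(freqs)
    h_full = amplitude * np.exp(1j * phase)
    h_rec = amp_rec * np.exp(1j * phi_rec)
    return flat_noise_mismatch(h_full, h_rec)
```

Running this for a few tolerances would make it easy to see where the reconstruction actually starts to degrade.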