rasbt / mlxtend

A library of extension and helper modules for Python's data analysis and machine learning libraries.
https://rasbt.github.io/mlxtend/

Why is bias_variance_decomp reporting a different MSE than sklearn.metrics mean_squared_error? #745

Closed · jbielski closed this issue 3 years ago

jbielski commented 3 years ago

I added mlxtend to my Python code in order to use bias_variance_decomp; I'd already been using sklearn.metrics.mean_squared_error. I noticed that the MSE it reported (and consequently the bias and variance) was very large and quite different from what sklearn.metrics.mean_squared_error reported. Here is the output:

    35/35 [==============================] - 0s 2ms/step - loss: 173.9796 - mae: 9.4905 - mse: 173.9796
    35/35 [==============================] - 0s 1ms/step - loss: 130.8012 - mae: 8.5116 - mse: 130.8012
    35/35 [==============================] - 0s 2ms/step - loss: 160.6490 - mae: 9.0809 - mse: 160.6490
    [... roughly 200 similar lines omitted; the per-round loss drifts downward from about 174 to 99 ...]
    35/35 [==============================] - 0s 2ms/step - loss: 120.4578 - mae: 7.9016 - mse: 120.4578
    35/35 [==============================] - 0s 2ms/step - loss: 99.1317 - mae: 7.2572 - mse: 99.1317
    MSE2: 3117429065.261
    Bias: 822492728009.764
    Variance: 169426455.907

Whereas sklearn.metrics.mean_squared_error reports this:

    Testing set Mean Abs Error: 12.70
    Testing set Mean Squared Error: 294.68
    Testing set Root Mean Squared Error: 16.22

Is this just a formatting error or something else? I use a Keras sequential model with an MSE loss function.

rasbt commented 3 years ago

Hm, not sure what's going on. Maybe the model didn't converge on one of the training (bootstrap) datasets.

When I try it on a simple example with Keras, it seems to work fine:

    from mlxtend.evaluate import bias_variance_decomp
    from mlxtend.data import boston_housing_data
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error

    X, y = boston_housing_data()
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        test_size=0.3,
                                                        random_state=123,
                                                        shuffle=True)

    import keras
    import tensorflow as tf

    model = keras.Sequential([
        keras.layers.Dense(32, activation=tf.nn.relu),
        keras.layers.Dense(1)
    ])

    optimizer = tf.keras.optimizers.Adam()
    model.compile(loss='mean_squared_error', optimizer=optimizer)

    model.fit(X_train, y_train, epochs=100)

    Epoch 1/100
    12/12 [==============================] - 0s 651us/step - loss: 10166.1113
    Epoch 2/100
    12/12 [==============================] - 0s 607us/step - loss: 5154.2451
    Epoch 3/100
    12/12 [==============================] - 0s 650us/step - loss: 3184.3140
    Epoch 4/100
    12/12 [==============================] - 0s 589us/step - loss: 2566.4712
    Epoch 5/100
    12/12 [==============================] - 0s 596us/step - loss: 2224.1023
    Epoch 6/100
    12/12 [==============================] - 0s 598us/step - loss: 1956.1360
    Epoch 7/100
    12/12 [==============================] - 0s 558us/step - loss: 1721.9098
    Epoch 8/100
    ...
    Epoch 99/100
    12/12 [==============================] - 0s 620us/step - loss: 61.8690
    Epoch 100/100
    12/12 [==============================] - 0s 507us/step - loss: 61.4688

    mean_squared_error(model.predict(X_test), y_test)

    63.55154033706853

    avg_expected_loss, avg_bias, avg_var = bias_variance_decomp(
        model, X_train, y_train, X_test, y_test,
        loss='mse',
        random_seed=123)

    avg_expected_loss, avg_bias, avg_var

    (32.459421052631576, 29.777634046052633, 2.6817870065789475)

Regarding the Keras support, it was just added recently and is not in the release version yet. Happy to add further improvements. I think one issue is that the model is not reset after fitting on each training set, so it continues training across bootstrap rounds. It would probably be good to reset it after each fit. I am not a Keras user and don't know the details wrt how to reset/reinitialize a model, but I added it as an issue here if you want to chime in: https://github.com/rasbt/mlxtend/issues/746
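
(For reference, one common way to obtain a freshly initialized copy of a Keras model, which the patch suggested further down in this thread also uses, is tf.keras.models.clone_model. A minimal sketch, not tied to any particular mlxtend API:)

    import tensorflow as tf

    def fresh_copy(model):
        # clone_model() copies only the architecture; the clone gets newly
        # initialized weights and must be compiled again before fitting.
        cloned = tf.keras.models.clone_model(model)
        cloned.compile(loss='mean_squared_error', optimizer='adam')
        return cloned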

Best, Sebastian


jbielski commented 3 years ago

In your output from bias_variance_decomp, is 'avg_expected_loss' the same as 'mean_squared_error'? The first says 32.45... and the second says 63.55.

rasbt commented 3 years ago

> 'avg_expected_loss' the same as 'mean_squared_error'

Yes, they should be roughly the same because they are on the same scale. The avg expected loss here is the expectation of the squared error loss estimated from bootstrap samples, which is on the same scale as the mean squared error loss.
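
For squared-error loss, the decomposition being estimated here (expectations taken over the bootstrap training sets, with no separate noise term) is

$$
\mathbb{E}\big[(\hat{y} - y)^2\big] \;=\; \underbrace{(\mathbb{E}[\hat{y}] - y)^2}_{\text{bias}^2} \;+\; \underbrace{\mathbb{E}\big[(\hat{y} - \mathbb{E}[\hat{y}])^2\big]}_{\text{variance}},
$$

i.e. avg_expected_loss = avg_bias + avg_var, which you can verify in the output above: 29.7776 + 2.6818 = 32.4594.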

The reason why the first one is so much lower than the second is that the model gets fit multiple times, so it keeps improving. I.e., the model with MSE 63.55 gets fit to the bootstrap datasets and continues to learn because it hadn't converged initially.
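
To see this effect in isolation, here is a minimal standalone sketch (toy data invented purely for illustration): Keras' fit() resumes from the model's current weights on every call, so calling it once per bootstrap round keeps improving the same network instead of retraining from scratch.

    import numpy as np
    import tensorflow as tf

    # Toy regression data, purely for illustration.
    rng = np.random.RandomState(0)
    X = rng.randn(200, 3).astype('float32')
    y = X @ np.array([1.0, -2.0, 0.5], dtype='float32')

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(loss='mse', optimizer='adam')

    # fit() does not reset the weights, so the loss keeps dropping
    # across "rounds" even though each call is only one epoch.
    for round_ in range(3):
        hist = model.fit(X, y, epochs=1, verbose=0)
        print(round_, hist.history['loss'][0])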

rasbt commented 3 years ago

Btw I think this

    for i in range(num_rounds):
        X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ == 'Sequential':

            cloned = tf.keras.models.clone_model(estimator)
            cloned.compile(loss='mean_squared_error')
            cloned.fit(X_boot, y_boot)
            pred = cloned.predict(X_test).reshape(1, -1)

might be a fairer way for this comparison.

jbielski commented 3 years ago

I patched the code with the above and now I see this:

    35/35 [==============================] - 0s 1ms/step - loss: 1137.4727
    35/35 [==============================] - 0s 1ms/step - loss: 1390.6079
    35/35 [==============================] - 0s 1ms/step - loss: 1599.6768
    ...
    35/35 [==============================] - 0s 1ms/step - loss: 1360.8937
    35/35 [==============================] - 0s 1ms/step - loss: 1537.2712
    35/35 [==============================] - 0s 1ms/step - loss: 1405.0632
    35/35 [==============================] - 0s 1ms/step - loss: 1256.0490
    35/35 [==============================] - 0s 913us/step - loss: 1387.0649
    35/35 [==============================] - 0s 2ms/step - loss: 1373.4651
    35/35 [==============================] - 0s 1ms/step - loss: 1192.3810
    MSE2: 2278441614.916  <---- orders of magnitude too high
    Bias: 626303296783.733
    Variance: 33626931.103

and the sklearn metrics code says:

    Testing set Mean Abs Error: 12.97
    Testing set Mean Squared Error: 308.48
    Testing set Root Mean Squared Error: 17.56

rasbt commented 3 years ago

I think the problem is stability during training. By default, fit runs only 1 epoch per bootstrap round. I changed it to 50 and it seems much more stable:

    for i in range(num_rounds):
        X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ == 'Sequential':

            cloned = tf.keras.models.clone_model(estimator)
            cloned.compile(loss='mean_squared_error')
            cloned.fit(X_boot, y_boot, epochs=50, verbose=0)
            pred = cloned.predict(X_test).reshape(1, -1)

For the full running code, I get:

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Bias-Variance Decomposition
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import numpy as np
import tensorflow as tf  # needed for the Keras (Sequential) branch below

def _draw_bootstrap_sample(rng, X, y):
    sample_indices = np.arange(X.shape[0])
    bootstrap_indices = rng.choice(sample_indices,
                                   size=sample_indices.shape[0],
                                   replace=True)
    return X[bootstrap_indices], y[bootstrap_indices]

def bias_variance_decomp(estimator, X_train, y_train, X_test, y_test,
                         loss='0-1_loss', num_rounds=200, random_seed=None):
    """
    estimator : object
        A classifier or regressor object or class implementing both a
        `fit` and `predict` method similar to the scikit-learn API.

    X_train : array-like, shape=(num_examples, num_features)
        A training dataset for drawing the bootstrap samples to carry
        out the bias-variance decomposition.

    y_train : array-like, shape=(num_examples)
        Targets (class labels, continuous values in case of regression)
        associated with the `X_train` examples.

    X_test : array-like, shape=(num_examples, num_features)
        The test dataset for computing the average loss, bias,
        and variance.

    y_test : array-like, shape=(num_examples)
        Targets (class labels, continuous values in case of regression)
        associated with the `X_test` examples.

    loss : str (default='0-1_loss')
        Loss function for performing the bias-variance decomposition.
        Currently allowed values are '0-1_loss' and 'mse'.

    num_rounds : int (default=200)
        Number of bootstrap rounds for performing the bias-variance
        decomposition.

    random_seed : int (default=None)
        Random seed for the bootstrap sampling used for the
        bias-variance decomposition.

    Returns
    ----------
    avg_expected_loss, avg_bias, avg_var : returns the average expected
        loss, the average bias, and the average variance (all floats),
        where the average is computed over the data points in the test set.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/evaluate/bias_variance_decomp/

    """
    supported = ['0-1_loss', 'mse']
    if loss not in supported:
        raise NotImplementedError('loss must be one of the following: %s' %
                                  supported)

    rng = np.random.RandomState(random_seed)

    # note: np.int is a deprecated alias for the builtin int; also, an
    # integer array truncates the float predictions in the 'mse' case
    all_pred = np.zeros((num_rounds, y_test.shape[0]), dtype=int)

    for i in range(num_rounds):
        X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ == 'Sequential':

            cloned = tf.keras.models.clone_model(estimator)
            cloned.compile(loss='mean_squared_error')
            cloned.fit(X_boot, y_boot, epochs=50, verbose=0)
            pred = cloned.predict(X_test).reshape(1, -1)
        else:
            pred = estimator.fit(X_boot, y_boot).predict(X_test)
        all_pred[i] = pred

    if loss == '0-1_loss':
        main_predictions = np.apply_along_axis(lambda x:
                                               np.argmax(np.bincount(x)),
                                               axis=0,
                                               arr=all_pred)

        avg_expected_loss = np.apply_along_axis(lambda x:
                                                (x != y_test).mean(),
                                                axis=1,
                                                arr=all_pred).mean()

        avg_bias = np.sum(main_predictions != y_test) / y_test.size

        var = np.zeros(pred.shape)

        for pred in all_pred:
            var += (pred != main_predictions).astype(int)
        var /= num_rounds

        avg_var = var.sum()/y_test.shape[0]

    else:
        avg_expected_loss = np.apply_along_axis(
            lambda x:
            ((x - y_test)**2).mean(),
            axis=1,
            arr=all_pred).mean()

        main_predictions = np.mean(all_pred, axis=0)

        avg_bias = np.sum((main_predictions - y_test)**2) / y_test.size
        avg_var = np.sum((main_predictions - all_pred)**2) / all_pred.size

    return avg_expected_loss, avg_bias, avg_var

from mlxtend.data import boston_housing_data
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X, y = boston_housing_data()
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.3,
                                                    random_state=123,
                                                    shuffle=True)

import keras
import tensorflow as tf

model = keras.Sequential([
    keras.layers.Dense(32, activation=tf.nn.relu),
    keras.layers.Dense(1)
  ])

optimizer = tf.keras.optimizers.Adam()
model.compile(loss='mean_squared_error', optimizer=optimizer)

#model.fit(X_train, y_train, epochs=100)
avg_expected_loss, avg_bias, avg_var = bias_variance_decomp(
        model, X_train, y_train, X_test, y_test, 
        loss='mse',
        random_seed=123)

avg_expected_loss, avg_bias, avg_var

(68.65805921052632, 33.903210526315796, 34.75484868421052)

Note that avg_bias + avg_var = 33.90 + 34.75 ≈ 68.66 = avg_expected_loss, as expected. And this is in the ballpark of

model = keras.Sequential([
    keras.layers.Dense(32, activation=tf.nn.relu),
    keras.layers.Dense(1)
  ])

optimizer = tf.keras.optimizers.Adam()
model.compile(loss='mean_squared_error', optimizer=optimizer)
model.fit(X_train, y_train, epochs=50)

...
Epoch 49/50
12/12 [==============================] - 0s 523us/step - loss: 58.0166
Epoch 50/50
12/12 [==============================] - 0s 542us/step - loss: 56.4809

mean_squared_error(model.predict(X_test), y_test)

57.8261745269078

rasbt commented 3 years ago

Hi Julie,

the recent PR (#748) should have fixed the issue. It should be more consistent now. Can you give it a try and see if that helped?

As shown in the newly added Example 3 in the user guide (https://github.com/rasbt/mlxtend/blob/master/docs/sources/user_guide/evaluate/bias_variance_decomp.ipynb), I recommend providing the number of training epochs as a fit_param to the function, where the epochs value should be the same as the one you used when training the network before.

avg_expected_loss, avg_bias, avg_var = bias_variance_decomp(
        model, X_train, y_train, X_test, y_test, 
        loss='mse',
        num_rounds=100,
        random_seed=123,
        epochs=200, # fit_param
        verbose=0) # fit_param