JuliaAI / DecisionTree.jl

Julia implementation of Decision Tree (CART) and Random Forest algorithms

Use seed! to put every copy of rng into a unique state #198

Closed dhanak closed 1 year ago

dhanak commented 1 year ago

Fixes #194.

Using rand(_rng, i) didn't really put each copy of rng into a unique state. The states were still interlocked: all the generators produced the same sequence of random numbers, only shifted by some offset. Calling seed! with a deterministic, pseudo-random seed for each thread produces much better results, which is also visible in the classification and regression accuracies produced by the tests.
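The interlocking can be illustrated with a minimal sketch. It is not the package code: it assumes a MersenneTwister and uses scalar draws to stand in for the rand(_rng, i) call, and the variable names are made up for the illustration.

```julia
using Random

rng = MersenneTwister(42)

# Advance two copies of the same generator by a different number of draws.
# Assumption: scalar draws stand in for the `rand(_rng, i)` call in the old code.
a = copy(rng)
b = copy(rng)
rand(a)            # copy for "tree 1": one draw discarded
rand(b); rand(b)   # copy for "tree 2": two draws discarded

xs = [rand(a) for _ in 1:5]
ys = [rand(b) for _ in 1:5]

# Same sequence, shifted by one position: the copies are interlocked.
interlocked = xs[2:end] == ys[1:end-1]

# Reseeding each copy with a distinct seed removes the shared sequence.
c = Random.seed!(copy(rng), 1)
d = Random.seed!(copy(rng), 2)
us = [rand(c) for _ in 1:5]
vs = [rand(d) for _ in 1:5]
overlap = !isempty(intersect(us, vs))
```

Here interlocked is true because both copies walk the same underlying sequence, while the reseeded copies are not expected to share any values.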

Here is the diff output of the unit tests before and after the change:

89c89
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
91c91
< Mean Accuracy: 0.8448448448448449
---
> Mean Accuracy: 0.9019019019019018
93c93
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
95c95
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
97c97
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.9039039039039038
99c99
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.9039039039039038
101c101
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.908908908908909
103c103
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.908908908908909
105c105
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.8988988988988988
107c107
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.8988988988988988
109c109
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.9009009009009009
111c111
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.9009009009009009
116,119c116,119
<  6    3    0  0
<  0  121   29  0
<  0   26  136  7
<  0    0    4  1
---
>  5    4    0  0
>  0  128   22  0
>  0    3  165  1
>  0    0    5  0
121,122c121,122
< Accuracy: 0.7927927927927928
< Kappa:    0.6153446948136739
---
> Accuracy: 0.8948948948948949
> Kappa:    0.7995390516158992
127,130c127,130
<  9    1    0  0
<  5  123   15  0
<  0   16  156  2
<  0    0    6  0
---
>  5    5    0  0
>  0  132   11  0
>  0   12  162  0
>  0    0    4  2
132,133c132,133
< Accuracy: 0.8648648648648649
< Kappa:    0.7499123817153157
---
> Accuracy: 0.9039039039039038
> Kappa:    0.818534791049351
138,141c138,141
<  7    3    0  0
<  3  121   33  0
<  0   11  144  4
<  0    0    6  1
---
>  1    9    0  0
>  0  140   17  0
>  0   10  149  0
>  0    0    7  0
143,144c143,144
< Accuracy: 0.8198198198198198
< Kappa:    0.6695445072938374
---
> Accuracy: 0.8708708708708709
> Kappa:    0.754849423890154
146c146
< Mean Accuracy: 0.8258258258258259
---
> Mean Accuracy: 0.8898898898898899
220,223c220,223
<  14    9    0   0
<   2  130    7   0
<   0   10  140   2
<   0    1    2  16
---
>  12   11    0   0
>   0  133    6   0
>   0    4  148   0
>   0    0    7  12
225,226c225,226
< Accuracy: 0.9009009009009009
< Kappa:    0.83520043190714
---
> Accuracy: 0.9159159159159159
> Kappa:    0.8573024594052737
231,234c231,234
<  10   13    0   0
<   1  139   16   0
<   0   13  120   1
<   0    0   10  10
---
>  14    9    0   0
>   0  140   16   0
>   0    4  128   2
>   0    0    3  17
236,237c236,237
< Accuracy: 0.8378378378378378
< Kappa:    0.7238297088094361
---
> Accuracy: 0.8978978978978979
> Kappa:    0.8300535867068942
242,245c242,245
<  16    1    0   0
<   1  126   10   0
<   0    1  150   0
<   0    0    7  21
---
>  14    3    0   0
>   0  132    5   0
>   0    2  145   4
>   0    0    4  24
247,248c247,248
< Accuracy: 0.93993993993994
< Kappa:    0.9009797945256397
---
> Accuracy: 0.9459459459459459
> Kappa:    0.911650256470727
250c250
< Mean Accuracy: 0.8928928928928929
---
> Mean Accuracy: 0.91991991991992
310,312c310,312
<         ├─ Feature 2 < 3.1 ?
<             ├─ Iris-virginica : 2/2
<             └─ Iris-versicolor : 1/1
---
>         ├─ Feature 1 < 5.95 ?
>             ├─ Iris-versicolor : 1/1
>             └─ Iris-virginica : 2/2
355,356c355,356
<   0  20  1
<   0   1  8
---
>   0  19  2
>   0   0  9
359c359
< Kappa:    0.9366286438529784
---
> Kappa:    0.9375780274656679
375,376c375,376
<   0  12   1
<   0   7  15
---
>   0  13   0
>   0   6  16
378,379c378,379
< Accuracy: 0.84
< Kappa:    0.7613365155131264
---
> Accuracy: 0.88
> Kappa:    0.8210023866348449
381c381
< Mean Accuracy: 0.9066666666666666
---
> Mean Accuracy: 0.9199999999999999
426c426
< Mean Accuracy: 0.8270217144261188
---
> Mean Accuracy: 0.8444669676587119
463,465c463,465
< Mean Squared Error:     2.0183096134238294
< Correlation Coeff:      0.8903914722230327
< Coeff of Determination: 0.7924911697044006
---
> Mean Squared Error:     1.2353634596621743
> Correlation Coeff:      0.9467692443993198
> Coeff of Determination: 0.8729883538187402
468,470c468,470
< Mean Squared Error:     1.9714838724549328
< Correlation Coeff:      0.910241766877058
< Coeff of Determination: 0.8011434924520122
---
> Mean Squared Error:     1.3297177364601998
> Correlation Coeff:      0.9564527935419953
> Coeff of Determination: 0.8658761409152053
473,475c473,475
< Mean Squared Error:     1.6739772387561769
< Correlation Coeff:      0.9029059136519314
< Coeff of Determination: 0.813068012307753
---
> Mean Squared Error:     1.1170134745442588
> Correlation Coeff:      0.9507514465365866
> Coeff of Determination: 0.8752638063163086
477c477
< Mean Coeff of Determination: 0.8022342248213886
---
> Mean Coeff of Determination: 0.8713761003500847
488c488
< Mean Coeff of Determination: 0.5825527898815513
---
> Mean Coeff of Determination: 0.6324059967649163

ablaom commented 1 year ago

Thanks @dhanak for this valuable contribution. @rikhuijzer I really think you are in the best position to review this PR, if you don't mind?

rikhuijzer commented 1 year ago

I think it's safe to assume that Random.seed!(rng, a) is not correlated with Random.seed!(rng, b) when a != b. That means that it should be safe to drop the shared_seed.

Below is an accuracy comparison between what is currently the dev branch of DecisionTree and a version that seeds each copy with the tree index only:

_rng = Random.seed!(copy(rng), i)

diff --git a/tree-old.txt b/tree-new.txt
index 5086247..f7fe469 100644
--- a/tree-old.txt
+++ b/tree-new.txt
@@ -48,64 +48,64 @@ Mean Accuracy: 0.8688688688688688
 ##### nfoldCV Classification Forest #####
 Testing nfoldCV_forest

-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999

-Mean Accuracy: 0.8448448448448449
+Mean Accuracy: 0.908908908908909

-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999

-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999

-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914

-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914

-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059

-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059

-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089

-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089

-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018

-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018

 Fold 1
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 6    3    0  0
- 0  121   29  0
- 0   26  136  7
- 0    0    4  1
+ 4    5    0  0
+ 1  124   25  0
+ 0    4  165  0
+ 0    0    5  0

-Accuracy: 0.7927927927927928
-Kappa:    0.6153446948136739
+Accuracy: 0.8798798798798799
+Kappa:    0.7701030394035107

 Fold 2
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 9    1    0  0
- 5  123   15  0
- 0   16  156  2
- 0    0    6  0
+ 8    2    0  0
+ 0  128   15  0
+ 0   10  164  0
+ 0    0    4  2

-Accuracy: 0.8648648648648649
-Kappa:    0.7499123817153157
+Accuracy: 0.9069069069069069
+Kappa:    0.8248409264443879

 Fold 3
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 7    3    0  0
- 3  121   33  0
- 0   11  144  4
+ 2    8    0  0
+ 0  141   16  0
+ 0    4  155  0
  0    0    6  1

-Accuracy: 0.8198198198198198
-Kappa:    0.6695445072938374
+Accuracy: 0.8978978978978979
+Kappa:    0.807114382091383

-Mean Accuracy: 0.8258258258258259
+Mean Accuracy: 0.8948948948948949

 ##### nfoldCV Adaboosted Stumps #####
 Testing nfoldCV_stumps
@@ -179,37 +179,37 @@ Mean Accuracy: 0.9629629629629629
 Fold 1
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 14    9    0   0
-  2  130    7   0
-  0   10  140   2
-  0    1    2  16
+ 17    6    0   0
+  0  135    4   0
+  0    8  144   0
+  0    0    4  15

-Accuracy: 0.9009009009009009
-Kappa:    0.83520043190714
+Accuracy: 0.933933933933934
+Kappa:    0.8896653513660049

 Fold 2
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 10   13    0   0
-  1  139   16   0
-  0   13  120   1
-  0    0   10  10
+ 13   10    0   0
+  0  143   13   0
+  0    7  125   2
+  0    0    1  19

-Accuracy: 0.8378378378378378
-Kappa:    0.7238297088094361
+Accuracy: 0.9009009009009009
+Kappa:    0.8349603508350355

 Fold 3
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
  16    1    0   0
-  1  126   10   0
+  0  127   10   0
   0    1  150   0
-  0    0    7  21
+  0    0   10  18

-Accuracy: 0.93993993993994
-Kappa:    0.9009797945256397
+Accuracy: 0.933933933933934
+Kappa:    0.8902800658978584

-Mean Accuracy: 0.8928928928928929
+Mean Accuracy: 0.9229229229229229

 ##### nfoldCV Adaboosted Stumps #####

@@ -265,13 +265,13 @@ Feature 3 < 2.45 ?
             └─ Iris-virginica : 1/1
         └─ Feature 4 < 1.55 ?
             ├─ Iris-virginica : 3/3
-            └─ Feature 1 < 6.95 ?
+            └─ Feature 3 < 5.45 ?
                 ├─ Iris-versicolor : 2/2
                 └─ Iris-virginica : 1/1
     └─ Feature 3 < 4.85 ?
-        ├─ Feature 2 < 3.1 ?
-            ├─ Iris-virginica : 2/2
-            └─ Iris-versicolor : 1/1
+        ├─ Feature 1 < 5.95 ?
+            ├─ Iris-versicolor : 1/1
+            └─ Iris-virginica : 2/2
         └─ Iris-virginica : 43/43

 ##### nfoldCV Classification Tree #####
@@ -314,33 +314,33 @@ Fold 1
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  20   0  0
-  0  20  1
+  0  18  3
   0   1  8

-Accuracy: 0.96
-Kappa:    0.9366286438529784
+Accuracy: 0.92
+Kappa:    0.8751560549313357

 Fold 2
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  15   0   0
   0  15   1
-  0   3  16
+  0   2  17

-Accuracy: 0.92
-Kappa:    0.8798076923076925
+Accuracy: 0.94
+Kappa:    0.9096929560505719

 Fold 3
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  15   0   0
-  0  12   1
-  0   7  15
+  0  13   0
+  0   3  19

-Accuracy: 0.84
-Kappa:    0.7613365155131264
+Accuracy: 0.94
+Kappa:    0.9090357792601576

-Mean Accuracy: 0.9066666666666666
+Mean Accuracy: 0.9333333333333332

 ##### nfoldCV Classification Adaboosted Stumps #####

@@ -385,7 +385,7 @@ Mean Accuracy: 0.8109892809975735

 ##### 3 foldCV Classification Forest #####

-Mean Accuracy: 0.8270217144261188
+Mean Accuracy: 0.8429005804846587

 ##### nfoldCV Classification Adaboosted Stumps #####

@@ -422,21 +422,21 @@ Mean Coeff of Determination: 0.821479058935842
 ##### nfoldCV Regression Forest #####

 Fold 1
-Mean Squared Error:     2.0183096134238294
-Correlation Coeff:      0.8903914722230327
-Coeff of Determination: 0.7924911697044006
+Mean Squared Error:     1.3577742526795888
+Correlation Coeff:      0.9396271935146402
+Coeff of Determination: 0.8604029108789377

 Fold 2
-Mean Squared Error:     1.9714838724549328
-Correlation Coeff:      0.910241766877058
-Coeff of Determination: 0.8011434924520122
+Mean Squared Error:     1.3034832328733625
+Correlation Coeff:      0.9529278684745566
+Coeff of Determination: 0.8685223212027657

 Fold 3
-Mean Squared Error:     1.6739772387561769
-Correlation Coeff:      0.9029059136519314
-Coeff of Determination: 0.813068012307753
+Mean Squared Error:     1.1485186853278506
+Correlation Coeff:      0.9420191589030741
+Coeff of Determination: 0.8717456392002396

-Mean Coeff of Determination: 0.8022342248213886
+Mean Coeff of Determination: 0.8668902904273144
 ==================================================
 TEST: regression/digits.jl

@@ -447,7 +447,7 @@ Mean Coeff of Determination: 0.6349826429860214

 ##### 3 foldCV Regression Forest #####

-Mean Coeff of Determination: 0.5825527898815513
+Mean Coeff of Determination: 0.6477805012747754
 ==================================================
 TEST: regression/scikitlearn.jl

@@ -496,5 +496,5 @@ TEST: miscellaneous/feature_importance_test.jl

 ==================================================
 Test Summary: | Pass  Total   Time
-Test Suites   | 9658   9658  53.0s
+Test Suites   | 9612   9612  53.6s
      Testing DecisionTree tests passed

What do you think @dhanak?

dhanak commented 1 year ago

> I think it's safe to assume that Random.seed!(rng, a) is not correlated with Random.seed!(rng, b) when a != b. That means that it should be safe to drop the shared_seed. What do you think @dhanak?

I agree with the assumption; that is why using shared_seed + i is good enough. I disagree with the conclusion, however. The role of shared_seed is not to disconnect the various copies of rng (i takes care of that), but to make the seeds depend on the current state of rng, and thus make them deterministically different for every unique state of rng.

In your version, every tree with a specific index draws the same sequence of numbers on each invocation, given a specific type of rng, irrespective of the state rng is in. I.e., the 1st tree always uses one set of numbers, the 2nd tree always uses another, etc. They are different from one another, but not different upon each invocation.
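The difference between the two variants can be sketched as follows. This is an illustration under stated assumptions, not the merged code: the helper names index_only, with_shared_seed and draws are hypothetical, and shared_seed is assumed to be drawn once per invocation with rand(rng, UInt).

```julia
using Random

# Hypothetical helpers, named here for illustration only.
# Variant A: the seed depends on the tree index alone.
index_only(rng, i) = Random.seed!(copy(rng), i)
# Variant B: the seed also depends on a value drawn from `rng`.
with_shared_seed(rng, shared_seed, i) = Random.seed!(copy(rng), shared_seed + i)

draws(g) = [rand(g) for _ in 1:3]

rng = MersenneTwister(1)

# First invocation. Assumption: the shared seed is drawn from `rng` itself.
s1 = rand(rng, UInt)
a1 = draws(index_only(rng, 1))
b1 = draws(with_shared_seed(rng, s1, 1))

# Second invocation: `rng` is now in a different state.
s2 = rand(rng, UInt)
a2 = draws(index_only(rng, 1))
b2 = draws(with_shared_seed(rng, s2, 1))

same_for_index_only = a1 == a2   # tree 1 repeats its numbers on every call
same_with_shared    = b1 == b2   # expected to differ between calls
```

With the index-only seed, tree 1 sees identical numbers in both invocations. With shared_seed + i the numbers change whenever rng is in a different state, yet the whole run stays reproducible from the initial state of rng.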

rikhuijzer commented 1 year ago

Now I get it. Thanks, David 😄

@ablaom Can you merge this and create a release? I don't yet understand how to create releases in the MLJ-style, unfortunately.

ablaom commented 1 year ago

Sure, I'll take care of it. FYI: new release instructions.

ablaom commented 1 year ago

Thanks @dhanak for this valuable contribution. Thank you @rikhuijzer for your generous engagement and review. 🙏🏾