dhanak closed this pull request 1 year ago.
Thanks @dhanak for this valuable contribution. @rikhuijzer I really think you are in the best position to review this PR, if you don't mind?
I think it's safe to assume that `Random.seed!(rng, a)` is not correlated with `Random.seed!(rng, b)` when `a != b`. That means that it should be safe to drop the `shared_seed`.
Below are the accuracy comparisons of what is currently the `dev` branch of DecisionTree versus one where `_rng = Random.seed!(copy(rng), i)`:
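For context, here is a minimal sketch (hypothetical standalone code, not the package's actual internals) of what re-seeding a copy of the shared generator per tree index looks like:

```julia
using Random

# Sketch: each tree i gets its own generator, obtained by copying the
# shared rng and re-seeding the copy with the tree's index.
rng = MersenneTwister(123)
n_trees = 4
tree_rngs = [Random.seed!(copy(rng), i) for i in 1:n_trees]

# Each tree now draws from its own deterministic stream.
samples = [rand(r, 3) for r in tree_rngs]
```

Because `seed!` fully resets the generator's state from the given seed, the streams for different indices are distinct and reproducible.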
diff --git a/tree-old.txt b/tree-new.txt
index 5086247..f7fe469 100644
--- a/tree-old.txt
+++ b/tree-new.txt
@@ -48,64 +48,64 @@ Mean Accuracy: 0.8688688688688688
##### nfoldCV Classification Forest #####
Testing nfoldCV_forest
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8448448448448449
+Mean Accuracy: 0.908908908908909
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
Fold 1
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 6 3 0 0
- 0 121 29 0
- 0 26 136 7
- 0 0 4 1
+ 4 5 0 0
+ 1 124 25 0
+ 0 4 165 0
+ 0 0 5 0
-Accuracy: 0.7927927927927928
-Kappa: 0.6153446948136739
+Accuracy: 0.8798798798798799
+Kappa: 0.7701030394035107
Fold 2
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 9 1 0 0
- 5 123 15 0
- 0 16 156 2
- 0 0 6 0
+ 8 2 0 0
+ 0 128 15 0
+ 0 10 164 0
+ 0 0 4 2
-Accuracy: 0.8648648648648649
-Kappa: 0.7499123817153157
+Accuracy: 0.9069069069069069
+Kappa: 0.8248409264443879
Fold 3
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 7 3 0 0
- 3 121 33 0
- 0 11 144 4
+ 2 8 0 0
+ 0 141 16 0
+ 0 4 155 0
0 0 6 1
-Accuracy: 0.8198198198198198
-Kappa: 0.6695445072938374
+Accuracy: 0.8978978978978979
+Kappa: 0.807114382091383
-Mean Accuracy: 0.8258258258258259
+Mean Accuracy: 0.8948948948948949
##### nfoldCV Adaboosted Stumps #####
Testing nfoldCV_stumps
@@ -179,37 +179,37 @@ Mean Accuracy: 0.9629629629629629
Fold 1
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 14 9 0 0
- 2 130 7 0
- 0 10 140 2
- 0 1 2 16
+ 17 6 0 0
+ 0 135 4 0
+ 0 8 144 0
+ 0 0 4 15
-Accuracy: 0.9009009009009009
-Kappa: 0.83520043190714
+Accuracy: 0.933933933933934
+Kappa: 0.8896653513660049
Fold 2
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 10 13 0 0
- 1 139 16 0
- 0 13 120 1
- 0 0 10 10
+ 13 10 0 0
+ 0 143 13 0
+ 0 7 125 2
+ 0 0 1 19
-Accuracy: 0.8378378378378378
-Kappa: 0.7238297088094361
+Accuracy: 0.9009009009009009
+Kappa: 0.8349603508350355
Fold 3
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
16 1 0 0
- 1 126 10 0
+ 0 127 10 0
0 1 150 0
- 0 0 7 21
+ 0 0 10 18
-Accuracy: 0.93993993993994
-Kappa: 0.9009797945256397
+Accuracy: 0.933933933933934
+Kappa: 0.8902800658978584
-Mean Accuracy: 0.8928928928928929
+Mean Accuracy: 0.9229229229229229
##### nfoldCV Adaboosted Stumps #####
@@ -265,13 +265,13 @@ Feature 3 < 2.45 ?
└─ Iris-virginica : 1/1
└─ Feature 4 < 1.55 ?
├─ Iris-virginica : 3/3
- └─ Feature 1 < 6.95 ?
+ └─ Feature 3 < 5.45 ?
├─ Iris-versicolor : 2/2
└─ Iris-virginica : 1/1
└─ Feature 3 < 4.85 ?
- ├─ Feature 2 < 3.1 ?
- ├─ Iris-virginica : 2/2
- └─ Iris-versicolor : 1/1
+ ├─ Feature 1 < 5.95 ?
+ ├─ Iris-versicolor : 1/1
+ └─ Iris-virginica : 2/2
└─ Iris-virginica : 43/43
##### nfoldCV Classification Tree #####
@@ -314,33 +314,33 @@ Fold 1
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
20 0 0
- 0 20 1
+ 0 18 3
0 1 8
-Accuracy: 0.96
-Kappa: 0.9366286438529784
+Accuracy: 0.92
+Kappa: 0.8751560549313357
Fold 2
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
15 0 0
0 15 1
- 0 3 16
+ 0 2 17
-Accuracy: 0.92
-Kappa: 0.8798076923076925
+Accuracy: 0.94
+Kappa: 0.9096929560505719
Fold 3
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
15 0 0
- 0 12 1
- 0 7 15
+ 0 13 0
+ 0 3 19
-Accuracy: 0.84
-Kappa: 0.7613365155131264
+Accuracy: 0.94
+Kappa: 0.9090357792601576
-Mean Accuracy: 0.9066666666666666
+Mean Accuracy: 0.9333333333333332
##### nfoldCV Classification Adaboosted Stumps #####
@@ -385,7 +385,7 @@ Mean Accuracy: 0.8109892809975735
##### 3 foldCV Classification Forest #####
-Mean Accuracy: 0.8270217144261188
+Mean Accuracy: 0.8429005804846587
##### nfoldCV Classification Adaboosted Stumps #####
@@ -422,21 +422,21 @@ Mean Coeff of Determination: 0.821479058935842
##### nfoldCV Regression Forest #####
Fold 1
-Mean Squared Error: 2.0183096134238294
-Correlation Coeff: 0.8903914722230327
-Coeff of Determination: 0.7924911697044006
+Mean Squared Error: 1.3577742526795888
+Correlation Coeff: 0.9396271935146402
+Coeff of Determination: 0.8604029108789377
Fold 2
-Mean Squared Error: 1.9714838724549328
-Correlation Coeff: 0.910241766877058
-Coeff of Determination: 0.8011434924520122
+Mean Squared Error: 1.3034832328733625
+Correlation Coeff: 0.9529278684745566
+Coeff of Determination: 0.8685223212027657
Fold 3
-Mean Squared Error: 1.6739772387561769
-Correlation Coeff: 0.9029059136519314
-Coeff of Determination: 0.813068012307753
+Mean Squared Error: 1.1485186853278506
+Correlation Coeff: 0.9420191589030741
+Coeff of Determination: 0.8717456392002396
-Mean Coeff of Determination: 0.8022342248213886
+Mean Coeff of Determination: 0.8668902904273144
==================================================
TEST: regression/digits.jl
@@ -447,7 +447,7 @@ Mean Coeff of Determination: 0.6349826429860214
##### 3 foldCV Regression Forest #####
-Mean Coeff of Determination: 0.5825527898815513
+Mean Coeff of Determination: 0.6477805012747754
==================================================
TEST: regression/scikitlearn.jl
@@ -496,5 +496,5 @@ TEST: miscellaneous/feature_importance_test.jl
==================================================
Test Summary: | Pass Total Time
-Test Suites | 9658 9658 53.0s
+Test Suites | 9612 9612 53.6s
Testing DecisionTree tests passed
What do you think @dhanak?
> I think it's safe to assume that `Random.seed!(rng, a)` is not correlated with `Random.seed!(rng, b)` when `a != b`. That means that it should be safe to drop the `shared_seed`. What do you think @dhanak?
I agree on the assumption; that is why using `shared_seed + i` is good enough. I disagree on the conclusion, however. The role of `shared_seed` is not to disconnect the various copies of `rng` (`i` takes care of that), but to make the seeds depend on the current state of `rng`, and thus make them deterministically different for every unique state of `rng`.
In your version, every tree with a specific index draws the same sequence of numbers on each invocation, given a specific class of `rng`, irrespective of the specific state `rng` is in. I.e., the 1st tree always uses one set of numbers, the 2nd tree always uses another, etc. They are different from one another, but not different upon each invocation.
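The distinction can be illustrated with a small sketch (the function names here are mine, for illustration only, not the package's API):

```julia
using Random

base = MersenneTwister()

# Scheme A: seed only with the tree index i. Because seed! discards the
# generator's current state, tree i gets the same stream on every
# invocation, whatever state `base` happens to be in.
tree_rng_a(rng, i) = Random.seed!(copy(rng), i)

# Scheme B: mix in a seed drawn from `rng` itself, so the per-tree
# streams also change with the current state of `rng`.
function tree_rng_b(rng, i)
    shared_seed = rand(rng, UInt)
    Random.seed!(copy(rng), shared_seed + i)
end

x1 = rand(tree_rng_a(base, 1))
rand(base, 100)                  # advance the state of `base`
x2 = rand(tree_rng_a(base, 1))
x1 == x2                         # scheme A: identical despite the new state
```

Under scheme B, the draw of `shared_seed` from `rng` makes the per-tree seeds deterministic functions of `rng`'s state, so the streams differ on each invocation.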
Now I get it. Thanks, David :smile:
@ablaom Can you merge this and create a release? I don't yet understand how to create releases in the MLJ-style, unfortunately.
Sure, I'll take care of it. FYI: new release instructions.
Thanks @dhanak for this valuable contribution. Thank you @rikhuijzer for your generous engagement and review. 🙏🏾
Fixes #194.

Using `rand(_rng, i)` didn't really put all copies of `rng` into a unique state; the states were still interlocked (all the generators produced the same sequence of random numbers with some offset). Calling `seed!` with a deterministic, pseudo-random seed for each thread produces much better results, which is also visible in the classification and regression accuracies produced by the tests.

Here is the diff output of the unit tests before and after the change:
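The interlocking can be seen in a small standalone example (illustrative only): copies of the same generator that are merely advanced by different amounts keep walking the same underlying sequence, just at an offset.

```julia
using Random

# Two copies of the same generator, advanced by 1 and 2 draws respectively,
# are still on the same stream, offset by one position.
rng = MersenneTwister(42)
r1 = copy(rng); rand(r1, 1)   # advance by 1 draw
r2 = copy(rng); rand(r2, 2)   # advance by 2 draws

# r2's upcoming draws should equal r1's upcoming draws shifted by one:
rand(r1, 5)[2:5] == rand(r2, 4)
```

Re-seeding each copy (instead of merely advancing it) breaks this offset relationship, which is the point of the fix.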