JuliaAI / DecisionTree.jl

Julia implementation of Decision Tree (CART) and Random Forest algorithms

RNG “shuffling” introduced in #174 is fundamentally flawed #194

Closed dhanak closed 1 year ago

dhanak commented 1 year ago

Merge #174 changed how the rngs are initialized per tree (for both classification and regression forests). Instead of using the same rng for all trees, it creates a copy for each of them, and then pushes each copy into a different state by drawing a different number of random numbers from it. (See here for details.) The number of random numbers drawn equals the index of the tree being built.

While I fully respect the intent, and appreciate that before this change, the code was not thread-safe, I must point out that this logic is fundamentally flawed. The problem is that drawing 1, 2, 3, etc. random numbers from the rng copies does not break the connection between them, and the resulting forest will not be random at all. Many trees, in fact, will be identical or very similar, and not by pure chance. I couldn't yet fully figure out why or where, but there is some implicit mechanism in the tree building which unintentionally resynchronizes the state of the rng copies. I have a hunch that it is related to the hypergeometric sampling in _split!() in tree.jl, but there could be something else, too.
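The overlap is visible even without building any trees. The following sketch (Julia 1.7+ for `Xoshiro`; variable names are illustrative) mimics the #174 scheme: copy `i` of a parent generator discards `i` values, so its stream is the parent's stream shifted by `i` positions, and the streams of consecutive trees share all but one value. It discards values with one scalar `rand` call each rather than `rand(_rng, i)`, because bulk array generation is not guaranteed to advance the state one value at a time.

```julia
using Random

# Mimic the #174 scheme: one parent RNG, one copy per tree,
# and copy i discards i values (one scalar draw = one state step).
base_rng = Xoshiro(2023)
copies = [copy(base_rng) for _ in 1:3]
for (i, r) in enumerate(copies)
    for _ in 1:i
        rand(r)
    end
end

# The next five values seen by tree 1 and by tree 2.
tree1 = [rand(copies[1]) for _ in 1:5]
tree2 = [rand(copies[2]) for _ in 1:5]

# Tree 2's stream is tree 1's stream shifted by a single position.
println(tree1[2:5] == tree2[1:4])   # prints true
```

Any procedure that makes tree 1 consume one extra value therefore puts it exactly where tree 2 started.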

The negative effect, however, is clearly noticeable: the classification/regression accuracy of the resulting forests is suboptimal, and subtle changes in certain hyperparameters (such as n_subfeatures) end up causing major changes in prediction accuracy. I'm still struggling to find a suitable, small enough example to demonstrate the effect clearly; I'll let you know if I find one.

I have begun working on a fix in a forked repo, the crux of which is to use pre-generated pseudo-random seeds to move all the generators into a unique state. Now, I understand that seed! is not necessarily implemented for all rng classes, so I added an applicability test, and fall back to the current behavior when seed! is not available.
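A minimal sketch of the fix as described in this paragraph, not the code in the fork: the parent rng pre-generates one seed per tree, each tree gets its own clone reseeded with its seed, and rng types for which `applicable(Random.seed!, rng, 0)` is false keep the #174 behaviour. The helper name `tree_rngs` is hypothetical and not part of DecisionTree.jl.

```julia
using Random

# Hypothetical helper (not DecisionTree.jl API): one RNG per tree.
function tree_rngs(rng::Random.AbstractRNG, n_trees::Integer)
    # deepcopy works generically, even for RNG types without a `copy` method.
    clones = [deepcopy(rng) for _ in 1:n_trees]
    # Not every RNG type implements `seed!` (the Julia docs cite RandomDevice).
    if applicable(Random.seed!, rng, 0)
        # Pre-generate one seed per tree from the parent, so the per-tree
        # streams still depend on the parent's state.
        seeds = rand(rng, UInt, n_trees)
        for i in 1:n_trees
            Random.seed!(clones[i], seeds[i])
        end
    else
        # Fallback kept from #174: clone i discards i values.
        for i in 1:n_trees
            rand(clones[i], i)
        end
    end
    return clones
end

rngs = tree_rngs(MersenneTwister(7), 4)
first_draws = [rand(r) for r in rngs]
```

Because every seed comes from the parent, calling `tree_rngs` twice with identically seeded parents reproduces the same forest of streams.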

To demonstrate the effect of this change, I ran the unit tests with the current version of DecisionTree and with my proposed change, and compared the results. The listed accuracies and confusion matrices improved noticeably in every single case, which sort of proves my point indirectly.

Here's the diff. The left-hand side is produced by the official package, while the right-hand side is the output of my fork.

89c92
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.9049049049049049
91c94
< Mean Accuracy: 0.8448448448448449
---
> Mean Accuracy: 0.8868868868868868
93c96
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.9049049049049049
95c98
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.9049049049049049
97c100
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.8948948948948949
99c102
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.8948948948948949
101c104
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.9029029029029029
103c106
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.9029029029029029
105c108
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.9009009009009009
107c110
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.9009009009009009
109c112
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.896896896896897
111c114
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.896896896896897
117,119c120,122
<  0  121   29  0
<  0   26  136  7
<  0    0    4  1
---
>  1  134   15  0
>  0    5  163  1
>  0    0    5  0
121,122c124,125
< Accuracy: 0.7927927927927928
< Kappa:    0.6153446948136739
---
> Accuracy: 0.9099099099099099
> Kappa:    0.8295047274464963
127,130c130,133
<  9    1    0  0
<  5  123   15  0
<  0   16  156  2
<  0    0    6  0
---
>  6    4    0  0
>  0  135    8  0
>  0   12  162  0
>  0    0    4  2
132,133c135,136
< Accuracy: 0.8648648648648649
< Kappa:    0.7499123817153157
---
> Accuracy: 0.9159159159159159
> Kappa:    0.8418266947139852
138,141c141,144
<  7    3    0  0
<  3  121   33  0
<  0   11  144  4
<  0    0    6  1
---
>  4    6    0  0
>  0  137   20  0
>  0    7  152  0
>  0    0    7  0
143,144c146,147
< Accuracy: 0.8198198198198198
< Kappa:    0.6695445072938374
---
> Accuracy: 0.8798798798798799
> Kappa:    0.7736156905401272
146c149
< Mean Accuracy: 0.8258258258258259
---
> Mean Accuracy: 0.901901901901902
221,223c224,226
<   2  130    7   0
<   0   10  140   2
<   0    1    2  16
---
>   0  133    6   0
>   0    9  143   0
>   0    0    3  16
225,226c228,229
< Accuracy: 0.9009009009009009
< Kappa:    0.83520043190714
---
> Accuracy: 0.918918918918919
> Kappa:    0.864122714220946
231,234c234,237
<  10   13    0   0
<   1  139   16   0
<   0   13  120   1
<   0    0   10  10
---
>  16    7    0   0
>   0  146   10   0
>   0    7  127   0
>   0    0    2  18
236,237c239,240
< Accuracy: 0.8378378378378378
< Kappa:    0.7238297088094361
---
> Accuracy: 0.9219219219219219
> Kappa:    0.8699511828764551
243,245c246,248
<   1  126   10   0
<   0    1  150   0
<   0    0    7  21
---
>   0  131    6   0
>   0    1  148   2
>   0    0    9  19
247,248c250,251
< Accuracy: 0.93993993993994
< Kappa:    0.9009797945256397
---
> Accuracy: 0.9429429429429429
> Kappa:    0.9058412084232457
250c253
< Mean Accuracy: 0.8928928928928929
---
> Mean Accuracy: 0.9279279279279279
306c309
<             └─ Feature 1 < 6.95 ?
---
>             └─ Feature 3 < 5.45 ?
355,356c358,359
<   0  20  1
<   0   1  8
---
>   0  19  2
>   0   0  9
359c362
< Kappa:    0.9366286438529784
---
> Kappa:    0.9375780274656679
365,366c368,369
<   0  15   1
<   0   3  16
---
>   0  16   0
>   0   2  17
368,369c371,372
< Accuracy: 0.92
< Kappa:    0.8798076923076925
---
> Accuracy: 0.96
> Kappa:    0.9399038461538461
375,376c378,379
<   0  12   1
<   0   7  15
---
>   0  13   0
>   0   1  21
378,379c381,382
< Accuracy: 0.84
< Kappa:    0.7613365155131264
---
> Accuracy: 0.98
> Kappa:    0.9693439607602697
381c384
< Mean Accuracy: 0.9066666666666666
---
> Mean Accuracy: 0.9666666666666667
426c429
< Mean Accuracy: 0.8270217144261188
---
> Mean Accuracy: 0.8432691421726712
463,465c466,468
< Mean Squared Error:     2.0183096134238294
< Correlation Coeff:      0.8903914722230327
< Coeff of Determination: 0.7924911697044006
---
> Mean Squared Error:     1.2449684223406625
> Correlation Coeff:      0.9546841050200022
> Coeff of Determination: 0.8720008370585812
468,470c471,473
< Mean Squared Error:     1.9714838724549328
< Correlation Coeff:      0.910241766877058
< Coeff of Determination: 0.8011434924520122
---
> Mean Squared Error:     1.364165257898544
> Correlation Coeff:      0.9504993092312798
> Coeff of Determination: 0.8624015429727003
473,475c476,478
< Mean Squared Error:     1.6739772387561769
< Correlation Coeff:      0.9029059136519314
< Coeff of Determination: 0.813068012307753
---
> Mean Squared Error:     1.0596279001140356
> Correlation Coeff:      0.9481698135943679
> Coeff of Determination: 0.8816720173987207
477c480
< Mean Coeff of Determination: 0.8022342248213886
---
> Mean Coeff of Determination: 0.872024799143334
488c491
< Mean Coeff of Determination: 0.5825527898815513
---
> Mean Coeff of Determination: 0.6096699730256437
dhanak commented 1 year ago

For an example of the “butterfly effect” mentioned above, have a look at the following heatmap of the EER of a regression forest. The number of genuine training samples is along the horizontal axis, while the number of used features is on the vertical axis. This surface should be fairly smooth, but notice the significant gap/jump as the number of features increases from 110 to 111. It causes an approximately 4% accuracy drop in EER values.

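[Heatmap: EER of a regression forest; horizontal axis: number of genuine training samples; vertical axis: number of used features]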

It took me a while to figure out the following:

  1. As the feature count increases from 110 to 111, the rounded square root increases from 10 to 11. This value is used (by default) for n_subfeatures.
  2. Feature sub-sampling begins by drawing a number from a hypergeometric distribution. For n_subfeatures = 10, this drawing uses a different algorithm than for 11. (There is an explicit if branch in utils.jl, in the implementation of the hypergeometric distribution, inherited from numpy.)
  3. The implementation called for 10, as it happens in this particular example, doesn't use the rng at all. The more complex implementation called for 11, though, does draw a few random values. The exact number depends on the random values themselves. I believe this is the key to the automatic resynchronization of the rng copies.
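Steps 1 and 3 can be reproduced in isolation. The sketch below (Julia 1.7+ for `Xoshiro`) first checks the rounding threshold, then replaces the hypergeometric sampler with a deliberately simple, hypothetical `toy_sampler!` whose number of draws depends on the drawn values; it is an illustration of the mechanism, not the code in utils.jl.

```julia
using Random

# Step 1: the rounded square root crosses from 10 to 11 between 110 and 111.
println(round(Int, sqrt(110)), " ", round(Int, sqrt(111)))   # prints 10 11

# Step 3, as a toy: keep drawing until a value below 0.5 appears, so the
# number of draws consumed depends on the values themselves.
function toy_sampler!(rng::Random.AbstractRNG)
    x = rand(rng)
    while x >= 0.5
        x = rand(rng)
    end
    return x
end

# Two tree RNGs offset by one draw, as in the #174 scheme.
function resynced(seed::Integer)
    parent_rng = Xoshiro(seed)
    a, b = copy(parent_rng), copy(parent_rng)
    rand(a)             # tree 1 discards 1 value
    rand(b); rand(b)    # tree 2 discards 2 values
    toy_sampler!(a)
    toy_sampler!(b)
    # Equal next draws mean both copies now share the same state.
    return rand(a) == rand(b)
end

println(count(resynced, 1:1000), " of 1000 seeds resynchronized")
```

With a one-draw offset, the copies resynchronize exactly when tree 1's first value is rejected: its second draw is then tree 2's first, and from there both consume identical values, so roughly half of the seeds end up resynchronized.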

With the proposed fix, not only does the wrinkle in the heatmap go away, but all accuracy values also improve noticeably.

ablaom commented 1 year ago

Wow. Fine work identifying and diagnosing a hairy issue. This represents a lot of work!

the crux of which is to use pre-generated pseudo-random seeds to move all the generators into a unique state.

For my part, I'm inclined to support your kind offer for a fix, but could you please describe this in a bit more detail and/or point to the relevant code in your fork (ideally both).

Note to self: breaking change.

@rikhuijzer @bensadeghi

dhanak commented 1 year ago

the crux of which is to use pre-generated pseudo-random seeds to move all the generators into a unique state.

For my part, I'm inclined to support your kind offer for a fix, but could you please describe this in a bit more detail and/or point to the relevant code in your fork (ideally both).

Sure, here's the diff: https://github.com/JuliaAI/DecisionTree.jl/compare/dev...dhanak:DecisionTree.jl:dev. First I check if seed! is applicable on the specific rng instance, and if yes, generate a vector of random uints that will act as seeds, one per tree. Then for each tree, after cloning the rng instance, I invoke seed! with the clone and the corresponding pre-generated seed. If, on the other hand, seed! is not applicable, then I fall back to the current “draw a variable number of values from rng” approach, which I'm still unhappy with.

I contemplated drawing not just i, but 1000i, or even 1,000,000i random values per tree (i.e., rand(_rng, 1_000_000i)), which should, in the vast majority of cases, put enough distance between the rng copies, and it's not too expensive for the built-in Xoshiro and MersenneTwister generators, either. But I felt uneasy about it, because the built-in generators do have a seed! implementation, so they will use the other branch anyhow. On the other hand, this fallback case is going to be used for other, third-party, or yet unknown generators, for which burning millions of random values could be a significant computational overhead.

This is the main reason why I haven't issued a pull request yet: I'm still uncertain how to handle the fallback case properly. Any ideas welcome!

rikhuijzer commented 1 year ago

Impressive! Well spotted!

The problem is that drawing 1, 2, 3, etc. random numbers from the rng copies does not break the connection between them, and the resulting forest will not be random at all. Many trees, in fact, will be identical or very similar, and not by pure chance.

Ouch.

The changes such as

     if rng isa Random.AbstractRNG
+        seeds = applicable(Random.seed!, rng, 0) ? rand(rng, UInt, n_trees) : nothing
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
             _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree.  This
-            # is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            if seeds !== nothing
+                # Seed the ring for each tree with a pseudo-random seed to put it
+                # into a predictable, but different state from all the others.
+                Random.seed!(_rng, seeds[i])
+            else
+                # Take some elements from the ring to have different states for each tree.
+                # This is the only way given that only a `copy` can be expected to exist for RNGs.
+                rand(_rng, i)
+            end

look good to me. Why is the applicable needed? Is that for Julia below 1.6?

rikhuijzer commented 1 year ago

Random.seed!(_rng, seeds[i]) could probably be replaced by Random.seed!(_rng, i). That way the seeds vector can be omitted. It doesn't matter too much for performance, but it should aid readability.

By the way, I wonder whether we should guarantee that the values in the current seeds vector are all distinct. In the current implementation, some trees might end up in exactly the same state. It shouldn't matter much, since the number of trees in a random forest is typically much smaller than the number of possible UInt values, but it's still good to double-check.
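For scale: with 64-bit UInt seeds and n trees, the birthday bound puts the chance of any collision at roughly n²/2⁶⁵, about 2.7e-14 for 1000 trees. If distinctness should nonetheless be guaranteed, one option is to redraw on collision; the helper `distinct_seeds` below is a hypothetical sketch of that (Julia 1.7+ for `Xoshiro` in the usage line).

```julia
using Random

# Hypothetical helper: draw n pairwise-distinct UInt seeds from rng,
# keeping them in draw order so the result stays reproducible.
function distinct_seeds(rng::Random.AbstractRNG, n::Integer)
    seen = Set{UInt}()
    seeds = UInt[]
    while length(seeds) < n
        s = rand(rng, UInt)
        # On the (astronomically rare) collision, skip the value and redraw.
        if !(s in seen)
            push!(seen, s)
            push!(seeds, s)
        end
    end
    return seeds
end

seeds = distinct_seeds(Xoshiro(1), 100)
```

The `seed0 + i` scheme discussed further down sidesteps the question altogether, since seed0 + 1, …, seed0 + n are distinct modulo 2⁶⁴ for any n below 2⁶⁴.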

dhanak commented 1 year ago

look good to me. Why is the applicable needed? Is that for Julia below 1.6?

I added that part because it is not required to implement Random.seed! for all rng classes. (From the Julia docs: “Some RNGs don't accept a seed, like RandomDevice”.)

Random.seed!(_rng, seeds[i]) could probably be replaced by Random.seed!(_rng, i). That way the seeds vector can be omitted. It doesn't matter too much for performance, but it should aid readability.

I think that would be too deterministic, i.e., the random numbers generated for the trees would not depend on the state of the rng passed to build_forest, only on its type. But we can employ the same trick as on the other branch (when the rng parameter is a number): generate a single initial seed, and then offset that seed with i. That would indeed save memory, and still use a different yet deterministic seed for each tree, which depends on the state of the passed rng. Something like this, perhaps?

        # Not all rngs are expected to implement `seed!`
        spread = if applicable(Random.seed!, rng, 0)
            seed0 = rand(rng, UInt)
            # Seed each ring with a different (but deterministic) seed.
            (rng, i) -> Random.seed!(rng, seed0 + i)
        else
            # Take some elements from the ring to have different states for each tree.
            (rng, i) -> rand(rng, i)
        end
        Threads.@threads for i in 1:n_trees
            # The Mersenne Twister (Julia's default) is not thread-safe.
            _rng = copy(rng)
            spread(_rng, i)
            inds = rand(_rng, 1:t_samples, n_samples)
            forest[i] = build_tree(
                labels[inds],
                features[inds,:],
                n_subfeatures,
                max_depth,
                min_samples_leaf,
                min_samples_split,
                min_purity_increase,
                rng = _rng,
                impurity_importance = impurity_importance)
        end
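The difference between the two proposals can be checked directly. In this sketch (hypothetical helper names, Julia 1.7+ for `Xoshiro`), seeding copy `i` with `i` alone gives tree 3 the same stream for two differently seeded parents, whereas seeding it with `seed0 + i`, where `seed0` is drawn once from the parent, makes the stream follow the parent's state.

```julia
using Random

# Scheme A (hypothetical helper): seed the copy for tree i with i alone.
function seed_by_index(rng::Random.AbstractRNG, i::Integer)
    r = copy(rng)
    Random.seed!(r, i)
    return r
end

# Scheme B (hypothetical helper): seed the copy with seed0 + i, where seed0
# was drawn once from the parent. UInt + Int promotes to UInt and wraps on overflow.
function seed_by_offset(rng::Random.AbstractRNG, seed0::UInt, i::Integer)
    r = copy(rng)
    Random.seed!(r, seed0 + i)
    return r
end

p1, p2 = Xoshiro(1), Xoshiro(2)     # two differently seeded parents

# Scheme A: tree 3 gets the same stream regardless of the parent.
a1, a2 = rand(seed_by_index(p1, 3)), rand(seed_by_index(p2, 3))

# Scheme B: tree 3's stream changes with the parent.
s1, s2 = rand(p1, UInt), rand(p2, UInt)
b1, b2 = rand(seed_by_offset(p1, s1, 3)), rand(seed_by_offset(p2, s2, 3))

println(a1 == a2)   # prints true
println(b1 == b2)   # false, barring an astronomically unlikely collision
```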
rikhuijzer commented 1 year ago

I added that part because it is not required to implement Random.seed! for all rng classes. (From the Julia docs: “Some RNGs don't accept a seed, like RandomDevice”.)

:+1:

I think that would be too deterministic, i.e., the random numbers generated for the trees would not depend on the state of the rng passed to build_forest, only on its type.

Better safe than sorry, I guess :+1:.

As a sidenote,

        # Not all rngs are expected to implement `seed!`
        spread = if applicable(Random.seed!, rng, 0)
            seed0 = rand(rng, UInt)
            # Seed each ring with a different (but deterministic) seed.
            (rng, i) -> Random.seed!(rng, seed0 + i)
        else
            # Take some elements from the ring to have different states for each tree.
            (rng, i) -> rand(rng, i)
        end

can be extracted into a separate function for readability. Something like:

function spread!(rng, i)
    if applicable(Random.seed!, rng, 0)
        seed0 = rand(rng, UInt)
        # Seed each ring with a different (but deterministic) seed.
        return Random.seed!(rng, seed0 + i)
    else
        # Take some elements from the ring to have different states for each tree.
        return rand(rng, i)
    end
end

I've added the exclamation mark to the end of the name since that's a convention for mutation of arguments in Julia.

@dhanak Your bug find helped me a lot for a paper that I'm currently working on! Can I mention you in the acknowledgements? I assume yes and let me know if you don't want that. You can respond here or to t.h.huijzer@rug.nl.

rikhuijzer commented 1 year ago

@dhanak do you want to open a pull request here? I can also do it if you want and make you co-author.

dhanak commented 1 year ago

can be extracted into a separate function for readability

I deliberately didn't want to make the if part of the function, because I didn't want to call applicable for every single tree. Also, this way, you mixed up the shared rng instance with the treewise copies. (I picked shadowing variable names out of laziness, I'll grant you that.) Here's another alternative, using multiple dispatch and proper functions. WDYT?

In utils.jl:

    using Random
    ...
    spread!(rng::Random.AbstractRNG, seed::Nothing, i::Integer) = rand(rng, i)
    spread!(rng::Random.AbstractRNG, seed::Integer, i::Integer) = Random.seed!(rng, seed + i)

And then:

        seed = applicable(Random.seed!, rng, 0) ? rand(rng, UInt) : nothing
        Threads.@threads for i in 1:n_trees
            # The Mersenne Twister (Julia's default) is not thread-safe.
            _rng = copy(rng)
            util.spread!(_rng, seed, i)
            inds = rand(_rng, 1:t_samples, n_samples)
            forest[i] = build_tree(
                labels[inds],
                features[inds,:],
                n_subfeatures,
                max_depth,
                min_samples_leaf,
                min_samples_split,
                min_purity_increase,
                loss = loss,
                rng = _rng,
                impurity_importance = impurity_importance)
        end

Your bug find helped me a lot for a paper that I'm currently working on! Can I mention you in the acknowledgements?

Absolutely, I'd be honored!

@dhanak do you want to open a pull request here? I can also do it if you want and make you co-author.

Yeah, I will, as soon as I find an implementation that we are moderately satisfied with, both syntactically and semantically. That being said, I'm still uneasy about the fallback solution.

rikhuijzer commented 1 year ago

Yes, the multiple dispatch looks good, yet dispatching on Nothing vs. Integer to encode whether seed! is applicable also makes me uneasy.

Maybe @ablaom knows whether we even need to handle RNGs that lack seed!. Since accuracy will be reduced when falling back to rand, maybe we should emit a warning?

    spread! = if applicable(Random.seed!, rng, 0)
        Random.seed!
    else
        @warn "The used RNG does not implement `Random.seed!`. Falling back to `rand` which will reduce accuracy of the fitted model."
        rand
    end
    Threads.@threads for i in 1:n_trees
        # The Mersenne Twister (Julia's default) is not thread-safe.
        _rng = copy(rng)
        seed0 = rand(rng, UInt)
        spread!(_rng, seed0 + i)
        ...
    end
dhanak commented 1 year ago

With this adjustment, we are more or less back to my previous, “inline” suggestion. Please note, however, that on the fallback branch you do not want to draw seed0 + i random numbers from any rng: with a random UInt seed0, that would be far too many to ever compute. The warning is a good idea, perhaps with a maxlog=1 setting to make it appear only once.

The if could be extracted as a separate function, e.g., make_spread!, which returns a function.

function make_spread!(rng::R)::Function where {R <: Random.AbstractRNG}
    # Not all rngs are expected to implement `seed!`
    if applicable(Random.seed!, rng, 0)
        seed0 = rand(rng, UInt)
        # Seed each ring with a different (but deterministic) seed.
        return (_rng::R, i::Integer) -> Random.seed!(_rng, seed0 + i)
    else
        @warn "The used RNG does not implement `Random.seed!`. Falling back to `rand` which will reduce accuracy of the fitted model." maxlog=1
        # Take some elements from the ring to have different states for each tree.
        return rand
    end
end
dhanak commented 1 year ago

What if we tie seeding and copying together, like this?


function replicate(rng::R, n::Integer)::Vector{R} where {R <: Random.AbstractRNG}
    clones = [deepcopy(rng) for _ in 1:n]
    # not all rngs are expected to implement `seed!`
    if applicable(Random.seed!, rng, 0)
        seed_base = rand(rng, UInt)
        # seed each ring with a different (but deterministic) seed
        for i in 1:n
            Random.seed!(clones[i], seed_base + i)
        end
    else
        @warn "The used RNG does not implement `Random.seed!`. Falling back to `rand` which will reduce accuracy of the fitted model." maxlog=1
        # take some elements from the ring to have different states for each clone
        for i in 1:n
            rand(clones[i], i)
        end
    end
    return clones
end

And then call as:

    if rng isa Random.AbstractRNG
        local_rngs = util.replicate(rng, n_trees)
        Threads.@threads for i in 1:n_trees
            _rng = local_rngs[i]
            inds = rand(_rng, 1:t_samples, n_samples)
            forest[i] = build_tree(
    ...

Please also note that I replaced copy with deepcopy, because deepcopy also works for RandomDevice, and presumably for other non-standard RNGs as well.

rikhuijzer commented 1 year ago

I like it because it moves all the complexity that we've been discussing throughout this issue into one separate place.

function replicate(rng::R, n::Integer)::Vector{R} where {R <: Random.AbstractRNG}

can probably be

function replicate(rng::Random.AbstractRNG, n::Integer)

since Julia will likely infer the return type automatically. For instance, replacing the loop with broadcasting over two RNGs:

julia> @code_warntype deepcopy.([Random.GLOBAL_RNG, Random.GLOBAL_RNG])
MethodInstance for (::var"##dotfunction#314#3")(::Vector{Random._GLOBAL_RNG})
  from (::var"##dotfunction#314#3")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#314#3"())
  x1::Vector{Random._GLOBAL_RNG}
Body::Vector{Random._GLOBAL_RNG}
1 ─ %1 = Base.broadcasted(Main.deepcopy, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(deepcopy), Tuple{Vector{Random._GLOBAL_RNG}}}
│   %2 = Base.materialize(%1)::Vector{Random._GLOBAL_RNG}
└──      return %2

which shows that the return type for the body is Body::Vector{Random._GLOBAL_RNG}.

ablaom commented 1 year ago

I see some very productive interaction here. Many thanks!

Do either of you know a good use case for non-seedable RNGs? If not, my vote is to remove support, unless you've found a way to reduce the complexity.

dhanak commented 1 year ago

I see some very productive interaction here. Many thanks!

Do either of you know a good use case for non-seedable RNGs? If not, my vote is to remove support, unless you've found a way to reduce the complexity.

RandomDevice doesn't support seed!. Even though I don't really see using RandomDevice as an rng passed to build_forest as a viable option, it is nonetheless an AbstractRNG instance. Other than that, I'm only familiar with the two other built-in generators, so the short answer is no.

Of course, it is a perfectly valid API design decision to accept only rngs that support seed!, but then, I believe, this requirement should be stated in the docstrings and wherever else applicable.

dhanak commented 1 year ago

So, @ablaom, what's your position on supporting RandomDevice? Shall I PR the more complicated code which uses seed! only if it is usable, or use it without checking, and add a disclaimer in the docstrings (where applicable)?

ablaom commented 1 year ago

Assuming an error is thrown if one attempts to use a non-seedable RNG (which I expect it would be), my vote would be for the latter option. @rikhuijzer, what do you think?
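A sketch of that stricter option, assuming support for non-seedable RNGs is dropped: the fallback branch disappears and every clone is reseeded with `seed_base + i`, as in `replicate`. The name `replicate_strict` and the error message are hypothetical; the explicit check only turns the failure for an rng such as `RandomDevice` into an `ArgumentError` that names the requirement, which the docstring also states. Julia 1.7+ is assumed for `Xoshiro` in the usage line.

```julia
using Random

"""
    replicate_strict(rng, n)

Hypothetical variant of `replicate`: return `n` clones of `rng`, clone `i` reseeded
with `seed_base + i`. Requires an RNG type implementing `Random.seed!(rng, ::Integer)`;
non-seedable RNGs such as `RandomDevice` are rejected.
"""
function replicate_strict(rng::Random.AbstractRNG, n::Integer)
    applicable(Random.seed!, rng, 0) ||
        throw(ArgumentError("rng of type $(typeof(rng)) does not implement Random.seed!"))
    # One draw from the shared rng makes all clones depend on its state.
    seed_base = rand(rng, UInt)
    clones = [deepcopy(rng) for _ in 1:n]
    for i in 1:n
        Random.seed!(clones[i], seed_base + i)
    end
    return clones
end

clones = replicate_strict(Xoshiro(42), 8)
```

Compared with `replicate`, the only behavioural change is that an unsupported rng fails loudly instead of silently degrading accuracy.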