ucl-pond / pySuStaIn

Subtype and Stage Inference (SuStaIn) algorithm with an example using simulated data.
MIT License

Update and improve RNG handling #26

Closed sea-shunned closed 3 years ago

sea-shunned commented 3 years ago

The Problems

  1. The output of pySuStaIn changes depending on whether use_parallel_startpoints is True or False
  2. The current method of resetting the seed can lead to a lack of randomness

Proposed Solution

The seed given when creating a SuStaIn instance is used to create that instance's global_rng attribute (via np.random.default_rng()). This generator is the direct source of randomness for everything outside the 3 parallelised functions, and the indirect source for those 3 functions (as explained below).

Previously, a seed was set inside the parallelised functions (via np.random.seed()), so each call replayed the same random stream. For example, the random cluster assignments (in AbstractSustain._find_ml_split_iteration) were not fully random on each subsequent split. The cluster sizes and data points changed between splits, so it's not terrible, but this is not ideal behaviour. Passing np.random.Generator objects with the same seed into the pool results in the same problem: because of the way multiprocessing is handled in Python (objects being pickled etc.), each worker operates on a copy, and the generator's state is not carried back after the processes are pooled. This is the cause of the difference between the results when running serially and in parallel.
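To make the two failure modes above concrete, here is a minimal sketch. It is not code from pySuStaIn: `assign_reseeded` and `assign_with_rng` are hypothetical stand-ins for a random cluster assignment, and the pickle round-trip only imitates what a process pool does when it ships a generator to a worker.

```python
import pickle

import numpy as np


def assign_reseeded(n_points, n_clusters, seed):
    # Hypothetical stand-in: resets the *global* seed on every call,
    # so every call replays the same random stream.
    np.random.seed(seed)
    return np.random.randint(0, n_clusters, size=n_points)


def assign_with_rng(n_points, n_clusters, rng):
    # Hypothetical stand-in: draws from a Generator whose state advances.
    return rng.integers(0, n_clusters, size=n_points)


# Reseeding inside the function: two "splits" get identical assignments.
same_when_reseeded = np.array_equal(
    assign_reseeded(1000, 2, seed=42), assign_reseeded(1000, 2, seed=42)
)

# One shared Generator: the second call continues the stream instead.
rng = np.random.default_rng(42)
same_with_shared_rng = np.array_equal(
    assign_with_rng(1000, 2, rng), assign_with_rng(1000, 2, rng)
)

# Pickling (as a process pool does) copies the Generator. Both copies start
# from the same state, and the parent's state is never advanced by them.
parent = np.random.default_rng(42)
copy_a = pickle.loads(pickle.dumps(parent))
copy_b = pickle.loads(pickle.dumps(parent))
copies_match = np.array_equal(copy_a.random(5), copy_b.random(5))

print(same_when_reseeded, same_with_shared_rng, copies_match)
```

The last case is why handing the same generator to every worker does not help: each worker draws from an identical copy.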

Therefore, the global_rng is used to spawn the required number (N_startpoints) of np.random.Generator objects (via np.random.default_rng()) for the multiprocessing pool, each with a different seed. Note that np.random.SeedSequence is used here to ensure a good variety of randomness, rather than consecutive integer seeds, which may not guarantee this.
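The spawning step could look roughly like the sketch below. The helper name `spawn_rngs` and the way the entropy value is drawn from the central generator are assumptions made for illustration; the actual change may derive the SeedSequence differently.

```python
import numpy as np


def spawn_rngs(global_rng, n_startpoints):
    # Assumption: draw one integer from the central generator and use it
    # as the entropy of a SeedSequence.
    entropy = int(global_rng.integers(0, 2**32))
    children = np.random.SeedSequence(entropy).spawn(n_startpoints)
    # One independent Generator per start point.
    return [np.random.default_rng(child) for child in children]


# The same central seed reproduces the same set of child generators...
rngs_a = spawn_rngs(np.random.default_rng(42), 4)
rngs_b = spawn_rngs(np.random.default_rng(42), 4)
draws_a = np.array([r.random(3) for r in rngs_a])
draws_b = np.array([r.random(3) for r in rngs_b])

# ...while each child within a set produces its own stream.
n_distinct_streams = len(np.unique(draws_a, axis=0))
```

Because the entropy is drawn from global_rng, the central generator advances on every call, so later rounds of start points receive fresh child seeds without touching NumPy's global state.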

As a result, this approach gives better randomness that is fully controlled from the central seed, with no modification to the global state (unlike np.random.seed(), which modifies it), and the same results are obtained whether or not parallelisation is used.
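A small sketch of the serial/parallel equivalence follows. It is illustrative only: `run_startpoint` and `run_all` are hypothetical, and a thread pool stands in for the multiprocessing pool to keep the example short. The point is that each task owns its seed, so the execution order cannot affect the result.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def run_startpoint(seed_seq):
    # Hypothetical unit of work: builds its own Generator from its own seed.
    rng = np.random.default_rng(seed_seq)
    return rng.normal(size=3)


def run_all(seed, n_startpoints, parallel):
    global_rng = np.random.default_rng(seed)
    entropy = int(global_rng.integers(0, 2**32))
    children = np.random.SeedSequence(entropy).spawn(n_startpoints)
    if parallel:
        # map() returns results in input order, whatever order tasks finish in.
        with ThreadPoolExecutor(max_workers=4) as pool:
            results = list(pool.map(run_startpoint, children))
    else:
        results = [run_startpoint(child) for child in children]
    return np.array(results)


serial = run_all(42, 5, parallel=False)
parallel = run_all(42, 5, parallel=True)
print(np.array_equal(serial, parallel))
```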