choderalab / chiron

Differentiable Markov Chain Monte Carlo
https://github.com/choderalab/chiron/wiki
MIT License

Random number generation with Chiron #16

Closed: chrisiacovella closed this issue 7 months ago

chrisiacovella commented 7 months ago

During review of the multistage sampling PR #8, the comment was raised:

For seed, I'm not clear how the seed is supposed to be passed through the call chain, but it would seem that the PRNG seed should be chained to the last move, so would have to enter at run().

This is an important point of discussion: we need to avoid running individual moves with the same random sequence, or parallel states with matching sequences.
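One way to read the reviewer's suggestion is to have the seed enter once, at `run()`, and to split the key before every move so that each move consumes a fresh subkey. The following is a minimal sketch under that assumption, using a toy random-walk move; `run()` and `toy_move()` are hypothetical stand-ins, not chiron's actual API:

```python
import jax.numpy as jnp
from jax import random


def toy_move(key, x):
    # Hypothetical move: a Gaussian random-walk step that consumes exactly
    # the key it is given (no Metropolis acceptance, for brevity).
    return x + 0.1 * random.normal(key, x.shape)


def run(seed, x0, n_steps):
    # The seed enters once, here; every step splits off a fresh subkey for
    # the move and carries the remaining key forward to the next step.
    key = random.PRNGKey(seed)
    x = x0
    trajectory = []
    for _ in range(n_steps):
        key, subkey = random.split(key)
        x = toy_move(subkey, x)
        trajectory.append(x)
    return jnp.stack(trajectory)
```

Calling `run(1342, jnp.zeros(3), 5)` twice gives identical trajectories, while a different seed gives a different one, and no two steps within a run share a subkey.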

An advantage is that the JAX RNG is not stateful, which will make it easier for us to maintain reproducibility.
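A quick illustration of what "not stateful" means in practice: a `jax.random` function called twice with the same key returns the same numbers, and getting new numbers requires explicitly deriving a new key. This is standard `jax.random` behavior; the seed value is arbitrary:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1342)

# Pure functions: the same key always yields the same draws,
# with no hidden global state advanced between calls.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# New numbers require a new key, derived explicitly by splitting.
_, subkey = random.split(key)
c = random.normal(subkey, (3,))
```

Here `a` and `b` are identical, while `c` differs. This contrasts with NumPy's global `np.random` functions, where two consecutive calls return different values because hidden state advances.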

The general idea that was discussed with @wiederm was the following:

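from jax import random

# Create the root key from a single seed
initial_key = random.PRNGKey(1342)
# Split it into four keys: one to carry forward, three to hand out
new_key, *three_other_keys = random.split(initial_key, num=4)
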
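Continuing that idea, here is a sketch of how the three extra keys might be handed to separate parallel states, each drawing from its own stream. Assigning one key per state is an assumption made for illustration:

```python
import jax.numpy as jnp
from jax import random

initial_key = random.PRNGKey(1342)
# One split produces four independent keys: one to carry forward,
# three to hand to (hypothetical) parallel states.
new_key, *three_other_keys = random.split(initial_key, num=4)

# Each state draws from its own key, so no two states share a stream.
draws = jnp.stack([random.uniform(k, (5,)) for k in three_other_keys])
```

Each row of `draws` comes from a different key, so the rows differ, and `new_key` remains available to be split again at the next stage.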
chrisiacovella commented 7 months ago

@wiederm added a wrapper class for the JAX PRNG in utils.py. It is primarily useful for allowing any number of instances to get unique initial keys from a single seed. This effectively makes initial key generation stateful (to avoid duplicate streams of numbers), while the rest of the code still benefits from JAX's stateless RNG.

This is slightly different from the solution above (generating n keys from a single seed in one split), but it should provide the same effective functionality, albeit with a different sequence. Note that because the class stores the "key" (and uses it to generate the next split) but returns the "subkey", each sequence will be distinct.
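To make the comparison concrete, here is a sketch of the two key-generation schemes side by side. The first produces n keys in a single split. The second repeatedly splits a stored key and returns the subkey, as the wrapper does. Both function names are hypothetical:

```python
import jax.numpy as jnp
from jax import random


def keys_one_shot(seed, n):
    # Split the seed's key into n keys at once.
    return list(random.split(random.PRNGKey(seed), num=n))


def keys_sequential(seed, n):
    # Wrapper-style: keep the "key", hand out the "subkey",
    # and split the stored key again for the next request.
    key = random.PRNGKey(seed)
    out = []
    for _ in range(n):
        key, subkey = random.split(key)
        out.append(subkey)
    return out
```

Each scheme yields n distinct keys and is reproducible from the seed. They simply produce different sequences of keys.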

For example, if we use this wrapper to get a random key and then split it (i.e., generate a new key and subkey), all of the resulting keys are unique:

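from jax import random
from chiron.utils import PRNG

PRNG.set_seed(12345)
for _ in range(4):
    # each call returns a fresh subkey; splitting it yields a new key and subkey
    print(random.split(PRNG.get_random_key()))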

[[2716353201 129531434], [1309294470 372939535]]
[[3532677114 196943530], [ 197247046 803331856]]
[[ 340871517 3565264368], [3786761370 4267842577]]
[[3518970598 3766832178], [4245378268 3812516042]]

For reference, the wrapper class added to utils.py:

import jax
from jax import random


class PRNG:
    _key: jax.Array
    _seed: int

    def __init__(self) -> None:
        """
        A PRNG class that can be used to generate random numbers in JAX.
        The intended use case is to initialize new PRN streams in the `SamplerState` class.

        Example:
        --------
        from chiron.utils import PRNG
        from chiron.states import SamplerState
        from openmmtools.testsystems import HarmonicOscillator

        ho = HarmonicOscillator()
        PRNG.set_seed(1234)
        # one independent key per sampler state
        sampler_states = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in range(4)]

        """

        pass

    @classmethod
    def set_seed(cls, seed: int) -> None:
        # Store the seed and the root key; calling this again restarts the stream.
        cls._seed = seed
        cls._key = random.PRNGKey(seed)

    @classmethod
    def get_random_key(cls) -> jax.Array:
        # Split the stored key: keep `key` for the next call and return `subkey`,
        # so callers never hold the key used to generate future keys.
        key, subkey = random.split(cls._key)
        cls._key = key
        return subkey

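A short usage sketch of the wrapper, showing that re-seeding reproduces the same keys and that successive calls give distinct keys. The class is condensed (docstring and `__init__` dropped) so the example runs without chiron installed:

```python
import jax
import jax.numpy as jnp
from jax import random


class PRNG:
    # Condensed copy of the wrapper above, so this example is self-contained.
    _key: jax.Array
    _seed: int

    @classmethod
    def set_seed(cls, seed: int) -> None:
        cls._seed = seed
        cls._key = random.PRNGKey(seed)

    @classmethod
    def get_random_key(cls) -> jax.Array:
        key, subkey = random.split(cls._key)
        cls._key = key
        return subkey


# Hand out one key per (hypothetical) sampler state.
PRNG.set_seed(1234)
first_run = [PRNG.get_random_key() for _ in range(3)]

# Re-seeding restarts the stream, so the same keys come back in the same order.
PRNG.set_seed(1234)
second_run = [PRNG.get_random_key() for _ in range(3)]
```

Because the key lives on the class, every caller in the process shares one stream; this is what lets n instances draw unique initial keys from a single seed.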
wiederm commented 7 months ago

Can we close that issue now?

chrisiacovella commented 7 months ago

Agree!