RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License

Add random initialization of Bandits reward distributions #9

Closed RPegoud closed 1 year ago

RPegoud commented 1 year ago

Currently, the bandits environment is initialized with:

class K_armed_bandits(BanditsBaseEnv):
    def __init__(self, K: int, SEED: int) -> None:
        super(K_armed_bandits, self).__init__()
        self.K = K
        self.actions = jnp.arange(K)
        self.init_key = random.PRNGKey(SEED)
        self.bandits_q = random.normal(self.init_key, shape=(K,))

Because bandits_q is sampled once in __init__ from a fixed seed, all n_runs parallel experiments share the same set of arm means. We want to sample one set of arm means per run, so that the comparison between different values of epsilon is averaged over many bandit problems instead of a single one, which makes it statistically more meaningful.
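The effect of the fixed seed can be checked directly. The snippet below is a minimal sketch that reproduces only the sampling line from __init__, outside the class; the values of K and SEED are hypothetical and chosen for illustration.

```python
import jax.numpy as jnp
from jax import random

K, SEED = 5, 0  # hypothetical values, for illustration only

# Two environments built from the same seed get identical arm means.
q_a = random.normal(random.PRNGKey(SEED), shape=(K,))
q_b = random.normal(random.PRNGKey(SEED), shape=(K,))
same_seed_identical = bool(jnp.array_equal(q_a, q_b))

# A different key yields a different set of arm means.
q_c = random.normal(random.PRNGKey(SEED + 1), shape=(K,))
other_key_differs = not bool(jnp.array_equal(q_a, q_c))
```

JAX random functions are deterministic in their key, so varying the arm means across runs requires a distinct key (or a larger sample shape) per run.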

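The reward is currently computed as:

    def get_reward(self, key, action):
        key, subkey = random.split(key)
        # return the fresh key, not the subkey already consumed by random.normal
        return random.normal(subkey) + self.bandits_q[action], key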

This should become something like:

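    # one column of arm means per run: shape (K, n_runs)
    bandits_q = random.normal(key, shape=(self.K, n_runs))

    def get_reward(self, key, action, bandits_q):
        # bandits_q: arm means of a single run, shape (K,), i.e. one column of the array above
        key, subkey = random.split(key)
        return random.normal(subkey) + bandits_q[action], key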
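One possible way to realise the per-run version is to map the reward function over runs with jax.vmap, so that each run sees its own column of arm means and no explicit run index is needed. The following is a minimal sketch under several assumptions: get_reward is written as a standalone function rather than a method of the jym environment class, it returns the unconsumed key instead of the subkey, and the values of K, n_runs and SEED as well as the names init_key, run_keys and actions are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

K, n_runs, SEED = 5, 4, 0  # hypothetical values, for illustration only


def get_reward(key, action, bandits_q):
    """Sample a reward for one run; bandits_q holds that run's arm means, shape (K,)."""
    key, subkey = random.split(key)
    reward = random.normal(subkey) + bandits_q[action]
    return reward, key  # assumption: the caller keeps splitting the returned key


init_key, run_key = random.split(random.PRNGKey(SEED))

# One column of arm means per run: shape (K, n_runs).
bandits_q = random.normal(init_key, shape=(K, n_runs))

# One independent key per run, and one action per run (here every run pulls arm 0).
run_keys = random.split(run_key, n_runs)
actions = jnp.zeros(n_runs, dtype=jnp.int32)

# Map over the leading axis of keys and actions, and over axis 1 (runs) of bandits_q.
batched_get_reward = jax.vmap(get_reward, in_axes=(0, 0, 1))
rewards, next_keys = batched_get_reward(run_keys, actions, bandits_q)
```

Here rewards has one entry per run and next_keys carries one key per run for the following step. Keeping the (K, n_runs) layout from the snippet above only requires in_axes=(0, 0, 1); storing the array as (n_runs, K) would allow the default in_axes instead.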