Closed RPegoud closed 10 months ago
Updated the K-armed bandits environment to take the average value of each bandit as an input to get_reward(), enabling each parallel run to have its own K randomly initialized arm values.
class K_armed_bandits(BanditsBaseEnv):
    """K-armed bandit environment with Gaussian rewards.

    The per-arm mean values (``bandits_q``) are not stored on the instance.
    They are passed to every reward function instead, so that each parallel
    run can supply its own independently initialized arm values.

    Attributes:
        K: Number of arms.
        actions: ``jnp.arange(K)``, the valid action indices.
        init_key: PRNG key built from ``SEED``; not consumed inside this class.
    """

    def __init__(self, K: int, SEED: int) -> None:
        super().__init__()
        self.K = K
        self.actions = jnp.arange(K)
        self.init_key = random.PRNGKey(SEED)

    def __repr__(self) -> str:
        return str(self.__dict__)

    @staticmethod
    @jit
    def get_reward(key, action, bandits_q):
        """Sample one reward for pulling ``action``.

        Args:
            key: JAX PRNG key.
            action: Index of the arm to pull (used to index ``bandits_q``).
            bandits_q: Per-arm mean values.

        Returns:
            ``(reward, next_key)``. ``reward`` is a standard-normal sample
            shifted by ``bandits_q[action]``. ``next_key`` is a key that has
            NOT been used for sampling and is safe to pass to the next call.
        """
        # One half of the split is consumed by the draw; the other half is
        # handed back untouched. Returning the consumed half (as the previous
        # version did) would make the caller reuse an already-used key.
        next_key, sample_key = random.split(key)
        reward = random.normal(sample_key) + bandits_q[action]
        return reward, next_key

    @staticmethod
    @jit
    def get_batched_reward(key, action, bandits_q):
        """Vectorised ``get_reward`` over a batch of keys and actions.

        ``key`` and ``action`` are mapped along their leading axis. The same
        ``bandits_q`` is shared by every element of the batch
        (``in_axes=None``).

        Returns:
            ``(rewards, next_keys)``, each stacked along axis 0.
        """
        return vmap(
            K_armed_bandits.get_reward,
            in_axes=(0, 0, None),
        )(key, action, bandits_q)

    @staticmethod
    @jit
    def multi_run_batched_reward(key, action, bandits_q):
        """Vectorise ``get_batched_reward`` across independent runs.

        The run dimension is mapped over axis 1 of ``key``, the last axis of
        ``action``, and axis 1 of ``bandits_q``. Both outputs place the run
        dimension on axis 1.
        NOTE(review): presumably ``bandits_q`` is laid out as (K, n_runs) so
        that each run gets its own arm values -- confirm against callers.
        """
        return vmap(
            K_armed_bandits.get_batched_reward,
            in_axes=(1, -1, 1),
            out_axes=(1, 1),
        )(key, action, bandits_q)
Updated the K-armed bandits environment to take the average value of each bandit as an input to get_reward(), enabling each parallel run to have its own K randomly initialized arm values.