RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License
32 stars 1 forks source link

added random initialization #10

Closed RPegoud closed 10 months ago

RPegoud commented 10 months ago

Updated the K-armed bandits environment so that get_reward() takes the average value of each bandit as an input, enabling each parallel run to have its own K randomly initialized values

class K_armed_bandits(BanditsBaseEnv):
    """K-armed bandit environment with jit-compiled, vmappable reward functions.

    The value of each arm is not stored on the instance; it is passed to the
    reward functions as ``bandits_q`` so that separate runs can each use
    their own arm values.
    """

    def __init__(self, K: int, SEED: int) -> None:
        """Store the number of arms, the action indices and the initial key.

        Args:
            K: Number of arms.
            SEED: Integer seed used to build ``init_key``.
        """
        super().__init__()
        self.K = K
        self.actions = jnp.arange(K)
        self.init_key = random.PRNGKey(SEED)
        # Arm values are deliberately not sampled here: callers supply them
        # to the reward functions through the ``bandits_q`` argument.

    def __repr__(self) -> str:
        """Return the instance attributes rendered as a dict string."""
        return str(self.__dict__)

    @staticmethod
    @jit
    def get_reward(key, action, bandits_q):
        """Sample the reward of one action: standard-normal noise plus arm value.

        Args:
            key: PRNG key for this step.
            action: Index into ``bandits_q`` of the arm that was pulled.
            bandits_q: Array of arm values, indexed by ``action``.

        Returns:
            A ``(reward, next_key)`` tuple. ``next_key`` has not been used for
            sampling and is the key to pass to the next call.
        """
        next_key, noise_key = random.split(key)
        reward = random.normal(noise_key) + bandits_q[action]
        # Hand back the untouched half of the split rather than the key that
        # just produced the noise: a JAX key must not be consumed twice.
        return reward, next_key

    @staticmethod
    @jit
    def get_batched_reward(key, action, bandits_q):
        """Apply ``get_reward`` across a batch of keys and actions.

        ``key`` and ``action`` are mapped over their leading axis; the same
        ``bandits_q`` is shared by every element of the batch.

        Returns:
            A ``(rewards, next_keys)`` tuple batched along the leading axis.
        """
        batched = vmap(
            K_armed_bandits.get_reward,
            in_axes=(0, 0, None),
        )
        return batched(key, action, bandits_q)

    @staticmethod
    @jit
    def multi_run_batched_reward(key, action, bandits_q):
        """Apply ``get_batched_reward`` across several independent runs.

        The run dimension is axis 1 of ``key``, the last axis of ``action``
        and axis 1 of ``bandits_q``, so each run uses its own arm values.

        Returns:
            A ``(rewards, next_keys)`` tuple with the run dimension placed on
            axis 1 of both outputs.
        """
        per_run = vmap(
            K_armed_bandits.get_batched_reward,
            in_axes=(1, -1, 1),
            out_axes=(1, 1),
        )
        return per_run(key, action, bandits_q)