RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License
32 stars 1 forks source link

added greedy-bandits, refactored repository #6

Closed RPegoud closed 10 months ago

RPegoud commented 10 months ago

Added the simple bandit agent:

class SimpleBandit:
    """Sample-average action-value estimator for a bandit agent.

    The class holds no state: the pull counts and value estimates are passed
    in and fresh copies are handed back on every call.
    """

    @staticmethod
    def __call__(action, reward, pulls, q_values):
        """Fold one observed reward into the estimate of the pulled arm.

        Args:
            action: index of the arm that was pulled.
            reward: reward observed for that pull.
            pulls: per-arm pull counts; must support the `.at[...].set(...)`
                functional-update syntax (JAX arrays do).
            q_values: per-arm value estimates, same requirements as `pulls`.

        Returns:
            A `(q_values, pulls)` tuple with the entry at `action` updated;
            the inputs themselves are not modified.
        """
        n_pulls = pulls[action] + 1
        old_estimate = q_values[action]

        # Incremental mean update:
        #   new_estimate = old_estimate + step_size * (target - old_estimate)
        # with step_size = 1 / n_pulls, which yields the running average.
        step_size = 1 / n_pulls
        new_estimate = old_estimate + step_size * (reward - old_estimate)

        updated_q_values = q_values.at[action].set(new_estimate)
        updated_pulls = pulls.at[action].set(n_pulls)
        return updated_q_values, updated_pulls

And the K-armed bandit "casino" environment:

class K_armed_bandits(BanditsBaseEnv):
    """K-armed bandit testbed with Gaussian rewards.

    The mean reward of every arm is drawn once, at construction, from a
    standard normal distribution. Pulling an arm returns that mean plus
    fresh standard-normal noise.
    """

    def __init__(self, K: int, SEED: int) -> None:
        """Create the testbed.

        Args:
            K: number of arms.
            SEED: integer seed for the PRNG key that generates the arm means.
        """
        super().__init__()
        self.K = K
        self.actions = jnp.arange(K)
        self.init_key = random.PRNGKey(SEED)
        self.bandits_q = random.normal(self.init_key, shape=(K,))

    def render(self):
        """
        Renders a violin plot of the reward distribution of each bandit
        """
        import pandas as pd
        import plotly.express as px

        # `init_key` was already consumed to draw `bandits_q`; sampling with
        # it again would tie the plotted noise to the arm means. Derive a
        # dedicated key instead (this leaves `bandits_q` unchanged).
        render_key = random.fold_in(self.init_key, 1)
        samples = random.normal(render_key, (self.K, 1000))
        # One column per arm, one row per sample.
        samples_df = pd.DataFrame(samples).T
        # Shift each arm's noise column by that arm's mean reward.
        shifted = samples_df + pd.Series(self.bandits_q)
        melted = shifted.melt()
        melted["mean"] = melted.variable.apply(lambda x: self.bandits_q[x])
        fig = px.violin(
            melted,
            x="variable",
            y="value",
            color="variable",
            title=f"{self.K}-armed Bandits testbed",
        )
        fig.update_traces(meanline_visible=True)
        fig.show()

    def __repr__(self) -> str:
        return str(self.__dict__)

    def get_reward(self, key, action):
        """Sample the reward for pulling `action`.

        Args:
            key: PRNG key; it is split here and must not be reused by the
                caller afterwards.
            action: index of the arm to pull.

        Returns:
            A `(reward, key)` tuple. `reward` is the arm's mean plus
            standard-normal noise; `key` is a fresh key to pass to the next
            call.
        """
        key, subkey = random.split(key)
        reward = random.normal(subkey) + self.bandits_q[action]
        # Hand back the unused half of the split. `subkey` has just been
        # consumed for sampling, so returning it would make the caller's
        # next split depend on an already-used key.
        return reward, key

Refactored the whole repo to differentiate tabular and bandit envs/agents/policies