Closed RPegoud closed 1 year ago
Added the simple bandit agent:
class SimpleBandit:
    """Stateless bandit agent that tracks action values by sample averages."""

    @staticmethod
    def __call__(action, reward, pulls, q_values):
        """Fold one observed `reward` for `action` into the running estimates.

        Args:
            action: index of the arm that was pulled.
            reward: reward observed for that pull.
            pulls: per-arm pull counts (JAX array).
            q_values: per-arm value estimates (JAX array).

        Returns:
            Updated `(q_values, pulls)` as new arrays (JAX functional update).
        """
        old_estimate = q_values[action]
        n = pulls[action] + 1
        # Incremental mean: new_estimate = old_estimate + step_size * (target - old_estimate),
        # with step_size = 1/n so the estimate equals the sample average.
        new_estimate = old_estimate + (1 / n) * (reward - old_estimate)
        return q_values.at[action].set(new_estimate), pulls.at[action].set(n)
And the K-armed bandit "casino" environment:
class K_armed_bandits(BanditsBaseEnv):
    """K-armed Gaussian bandit testbed.

    Each arm's true value is drawn once from a unit normal at construction;
    pulling an arm returns that value plus unit-variance Gaussian noise.
    """

    def __init__(self, K: int, SEED: int) -> None:
        super().__init__()
        self.K = K
        self.actions = jnp.arange(K)
        self.init_key = random.PRNGKey(SEED)
        # True mean reward of each arm, fixed for the lifetime of the env.
        self.bandits_q = random.normal(self.init_key, shape=(K,))

    def render(self):
        """
        Renders a violin plot of the reward distribution of each bandit
        """
        # Lazy imports: plotting dependencies are only needed when rendering.
        import pandas as pd
        import plotly.express as px

        samples = random.normal(self.init_key, (self.K, 1000))
        samples_df = pd.DataFrame(samples).T
        # Shift each arm's unit-normal samples by its true mean.
        shifted = samples_df + pd.Series(self.bandits_q)
        melted = shifted.melt()
        melted["mean"] = melted.variable.apply(lambda x: self.bandits_q[x])
        fig = px.violin(
            melted,
            x="variable",
            y="value",
            color="variable",
            title=f"{self.K}-armed Bandits testbed",
        )
        fig.update_traces(meanline_visible=True)
        fig.show()

    def __repr__(self) -> str:
        return str(self.__dict__)

    def get_reward(self, key, action):
        """Sample a reward for `action` and return `(reward, next_key)`.

        Bug fix: the previous version returned the *consumed* subkey as the
        carry key, so every subsequent split derived from a key that had
        already been fed to `random.normal`, producing correlated randomness.
        We now consume `subkey` and carry forward the fresh `key`.
        """
        key, subkey = random.split(key)
        reward = random.normal(subkey) + self.bandits_q[action]
        return reward, key
Refactored the whole repo to differentiate tabular and bandit envs/agents/policies
Added the simple bandit agent:
And the K-armed bandit "casino" environment:
Refactored the whole repo to differentiate tabular and bandit envs/agents/policies