Closed: RPegoud closed this issue 10 months ago.
Using the softmax policy as a proxy for the action probabilities, implement the Expected Sarsa algorithm.
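For reference, the usual form of the Expected Sarsa update with softmax action probabilities is given below. Here $\tau$ is a temperature that divides the Q-values (higher $\tau$ gives a more uniform policy), $\alpha$ is the step size and $\gamma$ the discount factor. Unlike Sarsa, the bootstrap term averages over all next actions under $\pi$ instead of using the single sampled next action.

```math
\pi(a \mid s) = \frac{\exp\big(Q(s,a)/\tau\big)}{\sum_{b} \exp\big(Q(s,b)/\tau\big)}
```

```math
Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \Big[ R_{t+1} + \gamma \sum_{a} \pi(a \mid S_{t+1})\, Q(S_{t+1}, a) - Q(S_t, A_t) \Big]
```

For a terminal $S_{t+1}$ the bootstrap sum is taken to be zero.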
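```python
from functools import partial

import jax.numpy as jnp
from jax import jit

# BasePolicy is defined elsewhere in the project.


class Softmax_policy(BasePolicy):
    def __init__(self, temperature) -> None:
        self.temperature = temperature

    # `self` (argument 0) is static, so jit treats the policy object as a constant.
    @partial(jit, static_argnums=(0,))
    def call(self, state, q_values):
        def _softmax_fn(q):
            # Divide by the temperature (higher temperature -> flatter distribution)
            # and subtract the max before exponentiating for numerical stability.
            z = (q - jnp.max(q)) / self.temperature
            return jnp.exp(z) / jnp.sum(jnp.exp(z))

        # Q-values of every action available in `state`.
        q = q_values.at[tuple(state)].get()
        return _softmax_fn(q)
```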
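A minimal tabular sketch of the Expected Sarsa update driven by softmax probabilities, written in plain Python so it runs without JAX. Assumptions (not from the issue): Q is a list of lists indexed by integer states, the temperature divides the Q-values, and the helper names `softmax_probs`, `sample_action` and `expected_sarsa_update` are hypothetical.

```python
import math
import random


def softmax_probs(q_row, temperature):
    """Softmax over one row of Q-values; the max is subtracted for stability."""
    scaled = [q / temperature for q in q_row]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(q_row, temperature, rng):
    """Draw an action index from the softmax (behaviour) distribution."""
    probs = softmax_probs(q_row, temperature)
    return rng.choices(range(len(q_row)), weights=probs, k=1)[0]


def expected_sarsa_update(q_table, state, action, reward, next_state, done,
                          alpha, gamma, temperature):
    """Apply one in-place Expected Sarsa update and return the TD error."""
    if done:
        expected_next = 0.0  # no bootstrapping from a terminal state
    else:
        next_row = q_table[next_state]
        probs = softmax_probs(next_row, temperature)
        # Expectation of Q(s', .) under the softmax policy, not a sampled action.
        expected_next = sum(p * q for p, q in zip(probs, next_row))
    td_error = reward + gamma * expected_next - q_table[state][action]
    q_table[state][action] += alpha * td_error
    return td_error
```

Because the target averages over next actions instead of sampling one, the update removes the variance that Sarsa incurs from the random choice of the next action. In the JAX version, the probability row would come from `Softmax_policy.call(state, q_values)` and the weighted sum would become a `jnp` dot product.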