RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License
32 stars, 1 fork

Implement expected SARSA #3

Closed: RPegoud closed this issue 10 months ago

RPegoud commented 10 months ago

Implement the Expected SARSA algorithm, using the softmax policy below to supply the action probabilities.
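For reference, the standard tabular Expected SARSA update replaces the sampled next action of plain SARSA with an expectation over the policy π, which in this issue would be the softmax distribution over the next state's Q-values:

```latex
Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \Big[ R_{t+1} + \gamma \sum_{a} \pi(a \mid S_{t+1})\, Q(S_{t+1}, a) - Q(S_t, A_t) \Big]
```

Here α is the step size and γ the discount factor. At a terminal transition the bootstrap term is dropped and the target is just R_{t+1}.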

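from functools import partial

import jax.numpy as jnp
from jax import jit

# BasePolicy is the policy base class defined elsewhere in jym.


class Softmax_policy(BasePolicy):
    def __init__(self, temperature) -> None:
        # Larger temperatures give flatter (more exploratory) distributions.
        self.temperature = temperature

    # `self` is static, so the jitted function is specialised per policy instance.
    @partial(jit, static_argnums=(0,))
    def call(self, state, q_values):
        def _softmax_fn(q_values):
            # Divide by the temperature, then subtract the max for numerical stability.
            logits = q_values / self.temperature
            logits = logits - jnp.max(logits)
            return jnp.exp(logits) / jnp.sum(jnp.exp(logits))

        # Q-values of every action available in `state`.
        q = q_values.at[tuple(state)].get()
        return _softmax_fn(q)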
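A minimal tabular sketch of one Expected SARSA step in JAX follows. It assumes a Q-table laid out as `q_values[*state, action]`, the same layout the `call` method reads from. The helper names `softmax_probs` and `expected_sarsa_update`, and the default `temperature`, `alpha` and `gamma` values, are hypothetical and not part of jym. The sketch uses the conventional temperature form exp(Q/τ) through `jax.nn.softmax`.

```python
import jax
import jax.numpy as jnp


def softmax_probs(q, temperature):
    """Hypothetical helper: softmax action probabilities, exp(q / temperature)."""
    return jax.nn.softmax(q / temperature)


@jax.jit
def expected_sarsa_update(q_values, state, action, reward, next_state, done,
                          temperature=1.0, alpha=0.1, gamma=0.99):
    """Hypothetical helper: one Expected SARSA update on a Q-table of shape (*grid, n_actions)."""
    s = tuple(state)
    s_next = tuple(next_state)

    q_next = q_values[s_next]                      # Q(s', .)
    pi_next = softmax_probs(q_next, temperature)   # pi(. | s')
    expected_q = jnp.sum(pi_next * q_next)         # sum_a pi(a | s') Q(s', a)

    # Do not bootstrap past a terminal transition.
    target = jnp.where(done, reward, reward + gamma * expected_q)
    td_error = target - q_values[s + (action,)]
    return q_values.at[s + (action,)].add(alpha * td_error)


# Usage: a 2x2 grid with 3 actions, moving from (0, 0) to (0, 1).
q = jnp.zeros((2, 2, 3)).at[0, 1].set(jnp.array([1.0, 2.0, 3.0]))
q = expected_sarsa_update(q, jnp.array([0, 0]), 1, 1.0, jnp.array([0, 1]), False)
```

Because the target averages over π(·|s') instead of sampling a next action, no next-action sample is needed, which removes that source of variance from the update.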