ikostrikov / implicit_q_learning

MIT License
226 stars · 38 forks

A question about `sample_actions()` #5

Closed · fuyw closed this 2 years ago

fuyw commented 2 years ago

https://github.com/ikostrikov/implicit_q_learning/blob/09d700248117881a75cb21f0adb95c6c8a694cb2/policy.py#L66

Hi Ilya,

Many thanks for the nice work. I have a question about the `sample_actions()` function: why do we need the separate `_sample_actions()`? Isn't it redundant?

Maybe we could simplify it to:

```python
import functools
import jax

@functools.partial(jax.jit, static_argnames=('actor_def',))
def sample_actions(rng, actor_def, actor_params, observations, temperature):
    # actor_def must be static (and hashable) for jit; Flax modules qualify.
    dist = actor_def.apply({'params': actor_params}, observations, temperature)
    rng, key = jax.random.split(rng)
    return rng, dist.sample(seed=key)
```
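
For context, a hypothetical usage sketch of this simplified function (the `MLPActor` module, action dimension, and observation shape below are all made up for illustration; distrax stands in for tfd, as mentioned further down):

```python
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp

class MLPActor(nn.Module):
    """Toy actor for illustration only, not the repo's actual policy."""
    action_dim: int

    @nn.compact
    def __call__(self, observations, temperature=1.0):
        x = nn.relu(nn.Dense(256)(observations))
        means = nn.Dense(self.action_dim)(x)
        return distrax.MultivariateNormalDiag(means, temperature * jnp.ones_like(means))

actor_def = MLPActor(action_dim=6)  # hypothetical action dimension
rng = jax.random.PRNGKey(0)
obs = jnp.zeros((1, 17))            # hypothetical observation shape
params = actor_def.init(rng, obs)['params']
rng, actions = sample_actions(rng, actor_def, params, obs, 1.0)
```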

Separately, I tried to reimplement IQL with TrainState. I found that using TrainState is slower than this implementation (by ~100-200 fps).
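
For reference, a minimal sketch (assuming Flax's `flax.training.train_state.TrainState`; the name `sample_actions_ts` is made up and this is not fuyw's actual code) of what the TrainState-based sampling might look like. Since `apply_fn` lives on the state as a non-pytree field, the module no longer needs to be passed via `static_argnames`:

```python
import jax
from flax.training import train_state

@jax.jit
def sample_actions_ts(rng, actor: train_state.TrainState, observations, temperature):
    # apply_fn (and tx) are non-pytree fields of TrainState, so jit treats
    # them as static; only the params are traced.
    dist = actor.apply_fn({'params': actor.params}, observations, temperature)
    rng, key = jax.random.split(rng)
    return rng, dist.sample(seed=key)
```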

ikostrikov commented 2 years ago

@fuyw I think the extra indirection is needed because otherwise there is some bug on Windows: https://github.com/ikostrikov/jaxrl/pull/18

That's cool! Is it this implementation? I will take a look.

fuyw commented 2 years ago

Thanks for the reply. Yes it is; I just refactored the code following the official Flax examples.

For simplicity, I replaced tfd with distrax; this did not matter in my experiments.

fuyw commented 2 years ago

Sorry Ilya, I found a bug in my previous implementation: I called `jax.device_put()` every time I sampled from the buffer, which wasted time. After fixing it, the throughput is now close to this implementation's.
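
For anyone who runs into the same slowdown, a toy illustration (the update function and batch below are made up) of why the explicit transfer is redundant: a jitted function already moves host-side numpy arrays to the device when it is called, so copying every batch beforehand only adds overhead:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def dummy_update(batch):
    # Stand-in for a real IQL update step.
    return jnp.mean(batch['observations'])

# Replay-buffer samples typically come back as host-side numpy arrays.
batch = {'observations': np.ones((256, 17), dtype=np.float32)}

# Redundant: the jitted call below already moves the arrays to the device.
# batch = jax.device_put(batch)
print(dummy_update(batch))
```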