fuyw closed this issue 2 years ago
@fuyw I think there is some bug on Windows otherwise: https://github.com/ikostrikov/jaxrl/pull/18
That's cool! Is it this implementation? I will take a look.
Thanks for the reply. Yes it is, and I just refactored the code according to the official Flax examples.
For simplicity, I replaced tfd with distrax, which made no difference in my experiments.
Sorry Ilya, I found a bug in my previous implementation. I called jax.device_put() when sampling from the buffer, which wasted time. After I fixed this bug, the throughput is now close to this implementation.
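For anyone hitting the same slowdown, here is a minimal sketch of the fix described above. `ToyBuffer` and `update` are hypothetical names made up for illustration and are not taken from either codebase. The idea is that `sample()` returns plain NumPy arrays and the jitted function handles the host-to-device transfer when it is called, so no explicit jax.device_put() is needed inside the sampling path.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyBuffer:
    """Hypothetical replay buffer that keeps transitions in NumPy arrays."""

    def __init__(self, obs_dim, size, seed=0):
        self.rng = np.random.default_rng(seed)
        self.observations = self.rng.normal(size=(size, obs_dim)).astype(np.float32)
        self.rewards = self.rng.normal(size=(size,)).astype(np.float32)
        self.size = size

    def sample(self, batch_size):
        idx = self.rng.integers(0, self.size, size=batch_size)
        # Return plain NumPy arrays; no jax.device_put() call here.
        return self.observations[idx], self.rewards[idx]


@jax.jit
def update(observations, rewards):
    # Stand-in for a real update step (assumption: any jitted function works the same way).
    return jnp.mean(observations.sum(axis=-1) * rewards)


buffer = ToyBuffer(obs_dim=4, size=1000)
obs, rew = buffer.sample(batch_size=256)
loss = update(obs, rew)  # NumPy inputs are transferred to the device by the jitted call
```

Whether this changes throughput on a given machine depends on the batch size and device, so it is worth timing both variants.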
https://github.com/ikostrikov/implicit_q_learning/blob/09d700248117881a75cb21f0adb95c6c8a694cb2/policy.py#L66
Hi Ilya,
Many thanks for the nice work. I have a question about the sample_actions() function: why do we need _sample_actions()? Isn't it redundant? Maybe we can simply:
Further, I tried to reimplement IQL with TrainState. I found that using TrainState is slower than this implementation (~100-200 fps).