hr0nix / dejax

Accelerated replay buffers in JAX
Apache License 2.0


An implementation of replay buffer data structures in JAX. Operations on dejax replay buffers can be jit-compiled and run on both CPU and GPU.

Package contents

How to use dejax replay buffers

import dejax

First, instantiate a buffer object. Buffer objects are stateless: they don't hold the buffer contents themselves, but provide functions to initialize and manipulate a separate buffer state.

buffer = dejax.uniform_replay(max_size=10000)
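To make the stateless buffer object idea concrete, here is a minimal sketch of what such an object could look like. It is not dejax's implementation: `make_uniform_replay`, `BufferState` and `ReplayBuffer` are hypothetical names, and the state layout (storage with a leading `max_size` axis, a write position, and a fill count) is an assumption. The object only bundles pure functions; every call takes a state and returns a new one.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class BufferState(NamedTuple):
    # Hypothetical layout: storage pytree with a leading [max_size] axis,
    # the next slot to write, and the number of filled slots.
    storage: Any
    insert_pos: jnp.ndarray
    size: jnp.ndarray


class ReplayBuffer(NamedTuple):
    # The buffer object only bundles pure functions; it holds no state itself.
    init_fn: Callable
    add_fn: Callable
    sample_fn: Callable
    update_fn: Callable


def make_uniform_replay(max_size: int) -> ReplayBuffer:
    def init_fn(item_prototype):
        # One zero-filled array per prototype leaf, with an extra leading axis.
        storage = jax.tree_util.tree_map(
            lambda x: jnp.zeros((max_size,) + jnp.shape(x), jnp.asarray(x).dtype),
            item_prototype,
        )
        zero = jnp.array(0, jnp.int32)
        return BufferState(storage, zero, zero)

    def add_fn(state, item):
        # Circular write: overwrite the slot at insert_pos, then advance it.
        storage = jax.tree_util.tree_map(
            lambda s, x: s.at[state.insert_pos].set(x), state.storage, item
        )
        return BufferState(
            storage,
            (state.insert_pos + 1) % max_size,
            jnp.minimum(state.size + 1, max_size),
        )

    def sample_fn(state, rng, batch_size):
        # Uniform indices over the filled slots [0, size).
        idx = jax.random.randint(rng, (batch_size,), 0, state.size)
        return jax.tree_util.tree_map(lambda s: s[idx], state.storage)

    def update_fn(state, item_update_fn):
        # Apply the per-item function to every slot via vmap.
        return state._replace(storage=jax.vmap(item_update_fn)(state.storage))

    return ReplayBuffer(init_fn, add_fn, sample_fn, update_fn)


buffer = make_uniform_replay(max_size=4)
state = buffer.init_fn({"obs": jnp.zeros(2), "reward": jnp.zeros(())})
for i in range(6):
    state = buffer.add_fn(state, {"obs": jnp.full(2, float(i)), "reward": jnp.array(float(i))})
batch = buffer.sample_fn(state, jax.random.PRNGKey(0), 8)
```

Because the state is an ordinary pytree (a `NamedTuple` of arrays), a function like `add_fn` can be wrapped in `jax.jit` directly; `sample_fn` would additionally need `batch_size` marked static, since it determines the output shape.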

Given a buffer object, we can initialize the replay buffer state. This requires a prototype item, which determines the structure of the storage: it must have the same structure and tensor shapes as the items that will be stored in the buffer.

buffer_state = buffer.init_fn(item_prototype)
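A prototype is enough to allocate storage because each leaf of the item pytree can be given an extra leading `max_size` axis. The sketch below assumes that layout; `allocate_storage` and the dictionary keys are hypothetical and not part of dejax.

```python
import jax
import jax.numpy as jnp


def allocate_storage(item_prototype, max_size):
    # For every leaf of the prototype, allocate a zero-filled array with an
    # extra leading max_size axis and the same dtype as the leaf.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.zeros((max_size,) + jnp.shape(leaf), jnp.asarray(leaf).dtype),
        item_prototype,
    )


# Hypothetical item layout: a 4-dim observation, an integer action, a scalar reward.
item_prototype = {
    "observation": jnp.zeros((4,), jnp.float32),
    "action": jnp.zeros((), jnp.int32),
    "reward": jnp.zeros((), jnp.float32),
}
storage = allocate_storage(item_prototype, max_size=1000)
# storage["observation"].shape → (1000, 4); storage["action"].shape → (1000,)
```

Under this layout only the prototype's structure, shapes and dtypes matter, not its values.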

Now we can fill the buffer:

for item in items:
    buffer_state = buffer.add_fn(buffer_state, item)
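The Python loop above dispatches one call per item. Assuming the buffer is a ring buffer (an assumption, not something the README states), insertion is a circular write that overwrites the oldest slot once the buffer is full, and under `jax.jit` a whole stacked batch of items can be inserted with `jax.lax.scan` in a single compiled call. `add` and `add_many` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def add(storage, insert_pos, item):
    # Circular write: overwrite slot insert_pos, then advance modulo capacity.
    max_size = jax.tree_util.tree_leaves(storage)[0].shape[0]
    storage = jax.tree_util.tree_map(lambda s, x: s.at[insert_pos].set(x), storage, item)
    return storage, (insert_pos + 1) % max_size


@jax.jit
def add_many(storage, insert_pos, items):
    # items carries an extra leading axis; lax.scan feeds one item per step.
    def step(carry, item):
        storage, pos = carry
        return add(storage, pos, item), None

    (storage, insert_pos), _ = jax.lax.scan(step, (storage, insert_pos), items)
    return storage, insert_pos


storage = jnp.zeros((3,), jnp.float32)
storage, pos = add_many(storage, jnp.array(0, jnp.int32), jnp.arange(5, dtype=jnp.float32))
# storage → [3., 4., 2.], pos → 2
```

With capacity 3, items 3 and 4 overwrite the two oldest entries (0 and 1), which is the usual FIFO eviction of a ring buffer.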

And sample from it:

batch = buffer.sample_fn(buffer_state, rng, batch_size)
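Uniform sampling only needs to draw indices from the slots filled so far. Assuming a ring buffer that fills slots starting from index 0 (so the filled slots are `[0, size)` until it wraps around), a sketch might look like this; `sample_uniform` is a hypothetical name. `batch_size` has to be static under `jax.jit` because it fixes the shape of the output.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("batch_size",))
def sample_uniform(storage, size, rng, batch_size):
    # Draw batch_size indices uniformly from the filled slots [0, size).
    idx = jax.random.randint(rng, (batch_size,), minval=0, maxval=size)
    # Gather the same indices from every leaf of the storage pytree.
    return jax.tree_util.tree_map(lambda s: s[idx], storage)


# A 10-slot buffer in which only the first 4 slots have been written so far.
storage = {"reward": jnp.arange(10, dtype=jnp.float32)}
batch = sample_uniform(storage, jnp.array(4, jnp.int32), jax.random.PRNGKey(0), batch_size=16)
```

Passing `size` as a traced array rather than a Python int means the compiled function is reused as the buffer fills, instead of being recompiled for every new size.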

Or apply an update op to the items in the buffer:

def item_update_fn(item):
    # Possibly update an item
    return item
buffer_state = buffer.update_fn(buffer_state, item_update_fn)
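An update op of this kind maps naturally onto `jax.vmap`: the per-item function is vectorized over the storage's leading axis, so every slot is transformed in one call. The sketch assumes that storage layout; `update_all` and the reward-clipping `item_update_fn` are hypothetical. With this approach unfilled slots are transformed too, which is harmless only if the update is well-defined on them.

```python
import jax
import jax.numpy as jnp


def update_all(storage, item_update_fn):
    # Vectorize the per-item function over the leading (slot) axis.
    return jax.vmap(item_update_fn)(storage)


def item_update_fn(item):
    # Hypothetical update: clip each stored reward into [-1, 1].
    return {**item, "reward": jnp.clip(item["reward"], -1.0, 1.0)}


storage = {
    "obs": jnp.zeros((4, 2)),
    "reward": jnp.array([-3.0, 0.5, 2.0, -0.25]),
}
# The update function is a Python callable, so it is passed as a static argument.
storage = jax.jit(update_all, static_argnums=1)(storage, item_update_fn)
# storage["reward"] → [-1.0, 0.5, 1.0, -0.25]
```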

Donating the buffer state to update it in-place

To benefit from manipulating replay buffers in jit-compiled code, you will probably want a top-level train_step function that updates both the train state and the replay buffer state given a fresh batch of trajectories:

train_state, replay_buffer_state = jax.jit(train_step, donate_argnums=(1,))(
    train_state, replay_buffer_state, trajectory_batch
)

It's important to specify donate_argnums in the call to jax.jit so that JAX can update the replay buffer state in-place. Without donate_argnums, any modification of the replay buffer state forces JAX to copy the whole state, which is likely to wipe out the performance benefits. You can read more about buffer donation in the JAX documentation.
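A minimal sketch of such a `train_step`, assuming the replay buffer state is a `(storage, insert_pos)` ring buffer and using a placeholder instead of a real gradient update (both assumptions, not dejax's API). It also assumes each batch holds at most `max_size` items, so the scattered indices are distinct. After a donated call the arrays passed in as argument 1 may be invalidated, so the loop rebinds `replay_buffer_state` to the returned value.

```python
import jax
import jax.numpy as jnp


def train_step(train_state, replay_buffer_state, trajectory_batch):
    storage, pos = replay_buffer_state
    n = trajectory_batch.shape[0]
    max_size = storage.shape[0]
    # Write the new trajectories into consecutive ring-buffer slots.
    idx = (pos + jnp.arange(n)) % max_size
    storage = storage.at[idx].set(trajectory_batch)
    # Placeholder for a real update computed from sampled data.
    new_train_state = train_state - 0.1 * jnp.mean(storage)
    return new_train_state, (storage, (pos + n) % max_size)


# Argument 1 (the whole replay buffer pytree) is donated.
train_step_jit = jax.jit(train_step, donate_argnums=(1,))

train_state = jnp.array(0.0)
replay_buffer_state = (jnp.zeros((8,)), jnp.array(0, jnp.int32))
for _ in range(3):
    trajectory_batch = jnp.ones((3,))
    train_state, replay_buffer_state = train_step_jit(
        train_state, replay_buffer_state, trajectory_batch
    )
```

On a backend that cannot use the donated buffers, JAX may emit a warning instead of updating in place.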

Note that buffer donation is not supported on CPU!
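If the same script has to run on both CPU and GPU, one option is to request donation only when the default backend is not the CPU. This is a sketch: whether donation is honored on CPU can depend on your JAX version, and `step` is a hypothetical stand-in for a real training step.

```python
import jax
import jax.numpy as jnp


def step(train_state, buffer_storage, new_item, pos):
    # Stand-in training step: count calls and overwrite one buffer slot.
    return train_state + 1, buffer_storage.at[pos].set(new_item)


# Donate the buffer (argument 1) only on non-CPU backends such as "gpu" or "tpu".
donate = (1,) if jax.default_backend() != "cpu" else ()
step_jit = jax.jit(step, donate_argnums=donate)

train_state = jnp.array(0, jnp.int32)
buffer_storage = jnp.zeros((4,), jnp.float32)
for i in range(4):
    train_state, buffer_storage = step_jit(train_state, buffer_storage, float(i), i)
```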