hr0nix / dejax

Accelerated replay buffers in JAX
Apache License 2.0

More efficient add_batch_fn implementation #3

Open seawee1 opened 1 year ago

seawee1 commented 1 year ago

I really appreciate this repository. However, for non-trivial batch sizes, add_batch_fn causes significant slowdowns because it adds every batch element sequentially.

In case anybody is interested, I worked around this with the following implementation. First, a new utility function that writes a whole batch into the data storage in a single operation:

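import jax

def set_pytree_batch_items(tree_batch, index, trees):
    # Write every leaf of `trees` into the matching storage leaf, starting at row `index`.
    # Start indices need one entry per dimension, so pad with zeros up to the leaf's rank.
    return jax.tree_util.tree_map(
        lambda tb, t: jax.lax.dynamic_update_slice(tb, t, (index,) + (0,) * (t.ndim - 1)),
        tree_batch, trees,
    )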
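
As a rough illustration of the difference, the sketch below compares a single slice update per leaf with an item-by-item loop on a toy storage pytree. The `set_items_sequentially` baseline, the `obs`/`reward` field names, and the shapes are hypothetical and are not taken from dejax; the helper is written so that leaves of different rank work, which is an assumption about the intended use.

```python
import jax
import jax.numpy as jnp


def set_pytree_batch_items(tree_batch, index, trees):
    # One slice update per leaf; start indices are padded to the leaf's rank.
    return jax.tree_util.tree_map(
        lambda tb, t: jax.lax.dynamic_update_slice(
            tb, t, (index,) + (0,) * (t.ndim - 1)
        ),
        tree_batch, trees,
    )


def set_items_sequentially(tree_batch, index, trees):
    # Hypothetical baseline: write one batch element per loop iteration.
    batch_size = jax.tree_util.tree_leaves(trees)[0].shape[0]

    def body(i, tb_tree):
        return jax.tree_util.tree_map(
            lambda tb, t: tb.at[index + i].set(t[i]), tb_tree, trees
        )

    return jax.lax.fori_loop(0, batch_size, body, tree_batch)


# Toy storage with capacity 6 and a batch of 3 items (illustrative shapes).
storage = {"obs": jnp.zeros((6, 2)), "reward": jnp.zeros((6,))}
batch = {"obs": jnp.ones((3, 2)), "reward": jnp.arange(1.0, 4.0)}

fast = jax.jit(set_pytree_batch_items)(storage, 2, batch)
slow = jax.jit(set_items_sequentially)(storage, 2, batch)

print(fast["reward"])  # rows 2..4 now hold the batch rewards
```

Both variants should produce the same storage as long as the batch fits before the end of the buffer; the batched one just does it with one update per leaf.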

Second, the outer function:

def add_batch_fn(state: UniformReplayBufferState, batch: ItemBatch) -> UniformReplayBufferState:
    buffer = state.storage

    # Write the whole batch with one slice update per leaf instead of looping over items.
    insert_pos = buffer.head
    new_data = utils.set_pytree_batch_items(buffer.data, insert_pos, batch)
    # Read the batch size from the first leaf so that any pytree structure works.
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]
    new_head = (insert_pos + batch_size) % circular_buffer.max_size(buffer)

    new_tail = jax.lax.select(
        buffer.full,
        on_true=0,  # Changed, due to the way `jax.lax.dynamic_update_slice` behaves inside `set_pytree_batch_items`
        on_false=buffer.tail,
    )
    new_full = new_head == new_tail

    return state.replace(storage=buffer.replace(data=new_data, head=new_head, tail=new_tail, full=new_full))
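
The behavior of `jax.lax.dynamic_update_slice` mentioned in the comment is presumably its clamping of start indices: when the update would run past the end of the array, the start index is shifted back so the update fits, rather than wrapping around. The sketch below shows that on a tiny array, together with one possible wraparound-aware alternative based on modular indices. The name `set_pytree_batch_items_wrapped` is hypothetical, the sketch assumes the batch is no larger than the buffer, and the scatter it uses may be slower than a slice update.

```python
import jax
import jax.numpy as jnp

storage = jnp.zeros(5, dtype=jnp.int32)
update = jnp.array([1, 2, 3], dtype=jnp.int32)

# Start index 4 with an update of length 3 does not fit into 5 slots,
# so the start index is clamped to 2 instead of wrapping around.
clamped = jax.lax.dynamic_update_slice(storage, update, (4,))
print(clamped)  # [0 0 1 2 3]


def set_pytree_batch_items_wrapped(tree_batch, index, trees):
    # Hypothetical variant: compute target rows modulo the capacity so that
    # a batch crossing the end of the storage continues at row 0.
    def write(tb, t):
        rows = (index + jnp.arange(t.shape[0])) % tb.shape[0]
        return tb.at[rows].set(t)

    return jax.tree_util.tree_map(write, tree_batch, trees)


wrapped = set_pytree_batch_items_wrapped(storage, 4, update)
print(wrapped)  # [2 3 0 0 1]
```

With a wrapping write like this, the head arithmetic `(insert_pos + batch_size) % max_size` would match where the data actually lands; the tail and full flags would still need their own treatment.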

I thought about doing a pull request, but I'm too lazy right now to do a clean adaptation of this quick fix to the code base.

hr0nix commented 1 year ago

Nice! Unfortunately I don't have a lot of time to support this repo atm, but I'll try to integrate your idea.