google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0

Best way to load dataset into reverb tables. #72

Closed ethanluoyc closed 2 years ago

ethanluoyc commented 2 years ago

Hi,

I am trying to populate a Reverb table with some NumPy arrays and was wondering what would be the most efficient way to do this.

I am currently doing something like

from absl import logging
import tree  # dm-tree
# `types` refers to the project's Transition container (e.g. acme.types).

def load_dataset_into_reverb(
    replay_client, dataset: types.Transition, table_name: str, num_keep_alive_refs=2
):
    """Load offline dataset into reverb"""
    logging.info("Populating reverb with offline data")
    table_size = replay_client.server_info()[table_name].max_size
    dataset_size = tree.flatten(dataset)[0].shape[0]
    if table_size < dataset_size:
        raise ValueError(
            f"Unable to insert dataset of size {dataset_size} into table with size {table_size}",
        )
    with replay_client.trajectory_writer(
        num_keep_alive_refs=num_keep_alive_refs
    ) as writer:
        for i in range(dataset_size):
            blob = types.Transition(
                observation=dataset.observation[i],
                action=dataset.action[i],
                reward=dataset.reward[i],
                discount=dataset.discount[i],
                next_observation=dataset.next_observation[i],
            )
            writer.append(blob)
            item = types.Transition(
                observation=writer.history.observation[-1],
                action=writer.history.action[-1],
                reward=writer.history.reward[-1],
                discount=writer.history.discount[-1],
                next_observation=writer.history.next_observation[-1],
            )
            writer.create_item(
                table=table_name,
                priority=1.0,
                trajectory=item,
            )
    logging.info("Populated reverb with offline data")

This loops over the rows of the dataset and creates one item per row. It works correctly, but I have run into performance issues when the dataset is large. Do the Reverb authors have any recommendations for the most efficient way to import data into Reverb?
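For reference, the per-row slicing pattern above can be sketched without Reverb at all: a struct-of-arrays dataset is turned into one per-row struct at a time. The `Transition` below is an illustrative stand-in for `types.Transition` (fields reduced for brevity):

```python
from typing import NamedTuple


class Transition(NamedTuple):
    # Illustrative stand-in for types.Transition; fields reduced for brevity.
    observation: list
    action: list
    reward: list


def iter_rows(dataset: Transition):
    """Yield one Transition per row of a struct-of-arrays dataset."""
    dataset_size = len(dataset.observation)
    for i in range(dataset_size):
        # Slice every field at index i, mirroring dataset.observation[i] etc.
        yield Transition(*(field[i] for field in dataset))


dataset = Transition(observation=[0.1, 0.2], action=[1, 0], reward=[1.0, 0.0])
rows = list(iter_rows(dataset))
# rows[0] == Transition(observation=0.1, action=1, reward=1.0)
```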

Thanks!

Best, Yicheng

acassirer commented 2 years ago

Hi Yicheng,

Your implementation looks pretty solid to me. The only thing I would change is to default num_keep_alive_refs to 1 instead of 2, since each item only references the most recently appended step (but this is unlikely to fix your performance issues).

My main advice would be to parallelise the insertion across multiple workers. You could do something along these lines:

from concurrent import futures
import operator

from absl import logging
import reverb
import tree  # dm-tree

def load_dataset_into_reverb(replay_client: reverb.Client,
                             dataset: types.Transition,
                             table_name: str,
                             num_keep_alive_refs: int = 1,
                             num_workers: int = 10):
    """Load offline dataset into reverb"""
    logging.info("Populating reverb with offline data")

    table_size = replay_client.server_info()[table_name].max_size
    dataset_size = tree.flatten(dataset)[0].shape[0]
    if table_size < dataset_size:
        raise ValueError(
            f"Unable to insert dataset of size {dataset_size} into table with "
            f"size {table_size}")

    def _run_worker(offset: int):
        with replay_client.trajectory_writer(num_keep_alive_refs) as writer:
            for i in range(offset, dataset_size, num_workers):
                blob = tree.map_structure(operator.itemgetter(i), dataset)
                writer.append(blob)

                item = tree.map_structure(operator.itemgetter(-1), writer.history)
                writer.create_item(
                    table=table_name,
                    priority=1.0,
                    trajectory=item,
                )

    with futures.ThreadPoolExecutor(num_workers) as executor:
        # Converting this to a list forces the futures to be resolved, which
        # means that any error raised by a worker is propagated here.
        list(executor.map(_run_worker, range(num_workers)))

    logging.info("Populated reverb with offline data")
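The round-robin sharding in `_run_worker` (worker `offset` handling `range(offset, dataset_size, num_workers)`) assigns every row to exactly one worker, so no row is skipped or inserted twice. A quick stdlib-only check of that property:

```python
def shard_indices(dataset_size: int, num_workers: int):
    """Return the strided index list each worker would process."""
    return [
        list(range(offset, dataset_size, num_workers))
        for offset in range(num_workers)
    ]


shards = shard_indices(dataset_size=10, num_workers=3)
# shards == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]

# Every index 0..9 appears in exactly one shard.
all_indices = sorted(i for shard in shards for i in shard)
assert all_indices == list(range(10))
```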

Cheers, Albin

ethanluoyc commented 2 years ago

Thank you! This is really useful!