google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0

No improved performance for inserting into Reverb via multithreading #75

Closed ethanluoyc closed 3 years ago

ethanluoyc commented 3 years ago

Hi, this is a follow-up to #72.

I profiled the solution suggested in #72 and found that it gives no speed-up.

Here is a complete repro script. I have modified a few things from the answer in #72: I fixed some problems in the solution and made the script self-contained.

from concurrent import futures
import operator
import time

from absl import app
from absl import flags
from absl import logging
from acme import specs
from acme import types
from acme.adders import reverb as adders_reverb
from acme.testing import fakes
import numpy as np
import reverb
from reverb import rate_limiters
import tensorflow as tf
import tree

FLAGS = flags.FLAGS

def load_dataset_into_reverb(
    replay_client: reverb.Client,
    dataset: types.Transition,
    table_name: str,
    num_keep_alive_refs: int = 1,
    num_workers: int = 10,
):
    """Load offline dataset into reverb"""
    logging.info("Populating reverb with offline data")

    table_size = replay_client.server_info()[table_name].max_size
    dataset_size = tree.flatten(dataset)[0].shape[0]
    if table_size < dataset_size:
        raise ValueError(
            f"Unable to insert dataset of size {dataset_size} into table with "
            f"size {table_size}"
        )

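    # Ceiling division, so that no items are silently dropped when
    # dataset_size is not a multiple of num_workers.
    num_items_per_worker = -(-dataset_size // num_workers)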

    def _run_worker(offset: int):
        with replay_client.trajectory_writer(num_keep_alive_refs) as writer:
            start_idx = offset * num_items_per_worker
            end_idx = min(dataset_size, (offset + 1) * num_items_per_worker)
            for i in range(start_idx, end_idx):
                blob = tree.map_structure(operator.itemgetter(i), dataset)
                writer.append(blob)
                item = tree.map_structure(operator.itemgetter(-1), writer.history)
                writer.create_item(
                    table=table_name,
                    priority=1.0,
                    trajectory=item,
                )

    start_time = time.time()
    with futures.ThreadPoolExecutor(num_workers) as executor:
        # Converting this to a list forces the futures to be resolved, which
        # means that any error raised by the workers is propagated here.
        list(executor.map(_run_worker, range(num_workers)))
    logging.info(
        "Populated reverb with offline data, time elapsed %.2f seconds",
        time.time() - start_time,
    )

def main(_):
    # Disable TF GPU
    tf.config.set_visible_devices([], "GPU")

    # environment = utils.make_environment("hopper-medium-v0", seed=0)
    environment = fakes.ContinuousEnvironment(action_dim=3, observation_dim=11)
    spec = specs.make_environment_spec(environment)

    dataset_size = int(1e5)
    dataset = types.Transition(
        observation=np.zeros(
            (dataset_size,) + spec.observations.shape, spec.observations.dtype
        ),
        next_observation=np.zeros(
            (dataset_size,) + spec.observations.shape, spec.observations.dtype
        ),
        reward=np.zeros((dataset_size,) + spec.rewards.shape, spec.rewards.dtype),
        action=np.zeros((dataset_size,) + spec.actions.shape, spec.actions.dtype),
        discount=np.zeros((dataset_size,) + spec.discounts.shape, spec.discounts.dtype),
    )
    replay_table = reverb.Table(
        name=adders_reverb.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=int(1e6),
        rate_limiter=rate_limiters.MinSize(1),
        signature=adders_reverb.NStepTransitionAdder.signature(environment_spec=spec),
    )
    replay_server = reverb.Server([replay_table], port=None)
    replay_client = reverb.Client(f"localhost:{replay_server.port}")
    # Load offline dataset into Reverb

    load_dataset_into_reverb(replay_client, dataset, replay_table.name, num_workers=10)

if __name__ == "__main__":
    FLAGS.logtostderr = True
    app.run(main)

The total loading time appears to be the same whether I use 10 workers or 1 worker. I am not sure what is happening. Could this be an issue with Python's GIL?

qstanczyk commented 3 years ago

I think you need to profile the program to see what is happening. It might be bottlenecked by Python's GIL. A quick check is to watch CPU usage as the number of workers increases: if CPU usage is similar with 1 worker and with more workers, the GIL could be the cause.
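A minimal sketch of that check, using only the standard library. The names `busy` and `cpu_to_wall_ratio` are made up for illustration and are not part of Reverb or Acme. It compares the CPU time of the whole process with wall-clock time while pure-Python work runs in several threads. On a standard CPython build with the GIL, the ratio is expected to stay near 1 regardless of the thread count. CPU usage can also rise because of native code that releases the GIL, so a higher number does not by itself rule out a GIL bottleneck in the Python part of a program.

```python
import threading
import time


def busy(n: int) -> int:
    """Pure-Python CPU-bound loop (hypothetical stand-in for per-item work)."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def cpu_to_wall_ratio(num_threads: int, n: int = 200_000) -> float:
    """Runs `busy` in `num_threads` threads and returns CPU time / wall time."""
    threads = [
        threading.Thread(target=busy, args=(n,)) for _ in range(num_threads)
    ]
    wall_start = time.perf_counter()
    # process_time() sums CPU time over all threads of this process.
    cpu_start = time.process_time()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    return cpu / wall if wall > 0 else 0.0


for workers in (1, 4):
    print(f"{workers} thread(s): CPU/wall ratio = {cpu_to_wall_ratio(workers):.2f}")
```

A ratio that grows roughly in line with the number of threads would suggest that the work is running in parallel; a ratio that stays flat points toward serialization.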

ethanluoyc commented 3 years ago

CPU usage does increase with more workers, so it might not be the GIL.

I did some basic profiling and here is what I got:

Mon Oct 11 09:53:26 2021    script.prof

         1778481 function calls (1737841 primitive calls) in 129.245 seconds

   Ordered by: cumulative time
   List reduced from 9010 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   3674/1    0.032    0.000  129.252  129.252 {built-in method builtins.exec}
        1    0.000    0.000  129.252  129.252 test_load_reverb.py:1(<module>)
        1    0.000    0.000  127.922  127.922 /home/yicheng/virtualenvs/orax/lib/python3.8/site-packages/absl/app.py:277(run)
        1    3.485    3.485  127.888  127.888 /home/yicheng/virtualenvs/orax/lib/python3.8/site-packages/absl/app.py:238(_run_main)
        1    0.002    0.002  124.403  124.403 test_load_reverb.py:67(main)
        1    0.001    0.001  124.366  124.366 test_load_reverb.py:22(load_dataset_into_reverb)
      101  124.358    1.231  124.358    1.231 {method 'acquire' of '_thread.lock' objects}
       23    0.000    0.000  124.358    5.407 /usr/lib/python3.8/threading.py:270(wait)
       11    0.000    0.000  123.257   11.205 /usr/lib/python3.8/concurrent/futures/_base.py:612(result_iterator)
       10    0.000    0.000  123.257   12.326 /usr/lib/python3.8/concurrent/futures/_base.py:416(result)
  2933/10    0.007    0.000    1.330    0.133 <frozen importlib._bootstrap>:986(_find_and_load)
  2932/10    0.006    0.000    1.330    0.133 <frozen importlib._bootstrap>:956(_find_and_load_unlocked)
  2854/10    0.006    0.000    1.329    0.133 <frozen importlib._bootstrap>:650(_load_unlocked)
  4878/10    0.001    0.000    1.329    0.133 <frozen importlib._bootstrap>:211(_call_with_frames_removed)
  2626/10    0.004    0.000    1.329    0.133 <frozen importlib._bootstrap_external>:842(exec_module)
 10246/21    0.006    0.000    1.327    0.063 <frozen importlib._bootstrap>:1017(_handle_fromlist)
   2120/9    0.003    0.000    1.327    0.147 {built-in method builtins.__import__}
        1    0.000    0.000    1.213    1.213 /home/yicheng/virtualenvs/orax/lib/python3.8/site-packages/acme/adders/reverb/__init__.py:16(<module>)
        1    0.000    0.000    1.103    1.103 /usr/lib/python3.8/concurrent/futures/_base.py:583(map)
        1    0.000    0.000    1.103    1.103 /usr/lib/python3.8/concurrent/futures/_base.py:608(<listcomp>)

It seems that a lot of time is spent on acquiring the thread lock somewhere but I am not sure where that is happening. Hope this is useful for identifying the issue.

qstanczyk commented 3 years ago

It is hard to tell without a stack trace showing where the lock is acquired. Can you get that from the profile?

ethanluoyc commented 3 years ago

I am not that familiar with cProfile so I am not entirely sure what's happening.

I am attaching the output generated from running python3 -m cProfile -o ... if that's useful: script.prof.zip.

qstanczyk commented 3 years ago

The main source of slowness is the map between append and create_item: tree.map_structure(operator.itemgetter(-1), writer.history)

Adding more threads doesn't help because the computations run sequentially due to the GIL. In fact, more threads make it run slower because of context switches. The cProfile output you provided shows only the main thread waiting for the workers to complete, so it doesn't tell us much.
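The point about the profile can be reproduced with a small standard-library sketch. `_worker`, `_main`, and `callers_of_lock_acquire` are hypothetical names, and `time.sleep` stands in for the real per-item work. `cProfile` records only the thread that enabled it, so a main thread that blocks on a thread pool mostly shows time in lock `acquire`; `pstats.Stats.print_callers` then lists which functions reached that call.

```python
import cProfile
import io
import pstats
import time
from concurrent import futures


def _worker(_: int) -> None:
    # Stand-in for real work done on a worker thread (assumption).
    time.sleep(0.05)


def _main() -> None:
    with futures.ThreadPoolExecutor(4) as executor:
        # Resolving the results makes the main thread block on the workers.
        list(executor.map(_worker, range(4)))


def callers_of_lock_acquire() -> str:
    """Profiles `_main` in the calling thread and reports callers of `acquire`."""
    profiler = cProfile.Profile()
    profiler.enable()
    _main()
    profiler.disable()
    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    stats.sort_stats("cumulative")
    # The argument is a regular expression matched against function names.
    stats.print_callers("acquire")
    return out.getvalue()


report = callers_of_lock_acquire()
print(report)
```

For a profile saved with `python3 -m cProfile -o`, the same report can be produced by passing the file name to `pstats.Stats` instead of a profiler object.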

ethanluoyc commented 3 years ago

I see. Is there a good workaround to speed up the map_structure call there without sacrificing generality (i.e., still allowing insertion of arbitrary stacked items)?

In terms of the GIL, I do not have a clue of how to solve that. Maybe it's possible to write some Cython code that releases the GIL when running the for loop. In that case, is the client thread safe?

qstanczyk commented 3 years ago

The easiest way is to use multiple processes to write data to Reverb.
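A rough sketch of what that could look like, under several assumptions that are not confirmed in this thread: the dataset is a flat dict of equally long arrays, the target table accepts items of that structure, and each child process builds its own `reverb.Client` from the server address. `shard_bounds`, `_write_shard`, and `load_with_processes` are hypothetical helpers; only `reverb.Client`, `trajectory_writer`, `append`, `history`, and `create_item` come from the script above. Whether this is actually faster has not been measured here.

```python
from concurrent import futures
import multiprocessing
from typing import Dict, List, Sequence, Tuple


def shard_bounds(dataset_size: int, num_workers: int) -> List[Tuple[int, int]]:
    """Splits range(dataset_size) into contiguous, near-equal (start, end) shards."""
    base, extra = divmod(dataset_size, num_workers)
    bounds = []
    start = 0
    for worker in range(num_workers):
        # The first `extra` workers take one additional row each.
        end = start + base + (1 if worker < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def _write_shard(address: str, table_name: str, shard: Dict[str, Sequence]) -> int:
    """Runs in a child process and writes every row of `shard` as one item."""
    import reverb  # Imported here so that each process creates its own client.

    client = reverb.Client(address)
    num_rows = len(next(iter(shard.values())))
    with client.trajectory_writer(1) as writer:
        for i in range(num_rows):
            writer.append({key: column[i] for key, column in shard.items()})
            writer.create_item(
                table=table_name,
                priority=1.0,
                # Mirrors the itemgetter(-1) pattern from the repro script.
                trajectory={key: col[-1] for key, col in writer.history.items()},
            )
    return num_rows


def load_with_processes(
    address: str,
    table_name: str,
    dataset: Dict[str, Sequence],
    num_workers: int = 4,
) -> int:
    """Writes `dataset` from `num_workers` processes; returns the rows written."""
    dataset_size = len(next(iter(dataset.values())))
    shards = [
        {key: column[start:end] for key, column in dataset.items()}
        for start, end in shard_bounds(dataset_size, num_workers)
    ]
    # "spawn" starts each worker with a fresh interpreter (a choice made for
    # this sketch, not a documented Reverb requirement).
    context = multiprocessing.get_context("spawn")
    with futures.ProcessPoolExecutor(num_workers, mp_context=context) as executor:
        pending = [
            executor.submit(_write_shard, address, table_name, shard)
            for shard in shards
        ]
        return sum(future.result() for future in pending)
```

With the `spawn` start method, the call to `load_with_processes` has to sit under an `if __name__ == "__main__":` guard in the calling script. Each process runs its own interpreter with its own GIL, which is what allows the per-item Python work to proceed in parallel.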

ethanluoyc commented 3 years ago

You are probably right. I will go with that then. Thanks!