google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0

Can not scale up with multiple servers #58

Closed bingykang closed 3 years ago

bingykang commented 3 years ago

Hi @acassirer

During my benchmarking I ran into a new issue, so I'm raising it here for a cleaner discussion. The main problem: adding more Reverb servers does not lead to higher throughput.

I conducted two series of experiments.

<1>. Varying `--num_servers` from `1` to `4`.
<2>. Keeping `--num_servers` at `1` and launching multiple experiments simultaneously, in other words running multiple servers independently.

The FPS (in thousands) are summarized as follows:

| servers | <1> | <2> |
| ------- |:------:|:------:|
| 1 | 224 | 224 |
| 2 | 210 + 198 = 408 | 274 |
| 3 | 197 + 182 + 172 = 551 | 271 |
| 4 | 183 + 165 + 159 + 136 = 643 | 268 |

The code is also attached for your information.

```python
"""Benchmarking reverb's multi-server performance.

python3 benchmark/bench_reverb_multiserver.py --lp_launch_type=local_mp --num_workers 32 --num_servers 1
"""
import itertools
import time

import numpy as np
import launchpad as lp
import reverb
import tensorflow as tf
import tree
from absl import app, flags

TEST_TABLE = 'queue'
WRITE_IDX = 0

FLAGS = flags.FLAGS
flags.DEFINE_integer('seq_length', 20, 'length of a trajectory')
flags.DEFINE_integer('batch_size', 64, 'the batch size to sample trajs')
flags.DEFINE_integer('num_workers', 64, 'number of producers')
flags.DEFINE_integer('num_servers', 4, 'number of servers')


class Spec:

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype


spec = {'obs': Spec((84, 84, 4), np.uint8), 'action': Spec((), np.int32)}


def get_queue_size(clients) -> int:
  queue_size = 0
  for client in clients:
    table_info = client.server_info()[TEST_TABLE]
    queue_size += (
        table_info.rate_limiter_info.insert_stats.completed -
        table_info.rate_limiter_info.sample_stats.completed)
  return queue_size


class Producer:
  """A producer generating fake data."""

  def __init__(self, client, seq_length=1):
    self._client = client
    self._seq_length = seq_length

  def run(self):
    writer = self._client.trajectory_writer(
        num_keep_alive_refs=self._seq_length,
        get_signature_timeout_ms=300_000,
    )

    def randint(x):
      return np.random.randint(255, size=x.shape, dtype=x.dtype)

    for cnt in itertools.count(1):
      data = tree.map_structure(randint, spec)
      writer.append(data)
      if cnt % self._seq_length == 0:
        traj = tree.map_structure(lambda x: x[-self._seq_length:],
                                  writer.history)
        writer.create_item(TEST_TABLE, priority=1, trajectory=traj)
        writer.flush(self._seq_length)
        # pylint: disable=protected-access
        for i, col in enumerate(writer._column_history):
          writer._column_history[i]._data_references = \
              col._data_references[-self._seq_length:]
      if cnt % 1000 == 0:
        print(cnt)


class Consumer:
  """A consumer to read data from the reverb server."""

  def __init__(self, clients, seq_length, batch_size=1) -> None:
    self._clients = clients
    self._seq_length = seq_length
    self._batch_size = batch_size
    self._cps = cps = 16  # number of parallel calls per server
    self._addresses = [
        clients[i // cps].server_address for i in range(cps * len(clients))
    ]

  def _make_dataset(self):
    # Get signature info
    info = self._clients[0].server_info()
    shapes = tree.map_structure(lambda x: x.shape, info[TEST_TABLE].signature)
    dtypes = tree.map_structure(lambda x: x.dtype, info[TEST_TABLE].signature)

    def _inner_make_dataset(server_address):
      ds = reverb.TrajectoryDataset(
          server_address=server_address,
          table=TEST_TABLE,
          max_in_flight_samples_per_worker=100,
          shapes=shapes,
          dtypes=dtypes,
      )
      return ds

    num_parallel_calls = len(self._addresses)
    ds = tf.data.Dataset.from_tensor_slices(self._addresses)
    ds = ds.interleave(
        map_func=_inner_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False,
    )
    ds = ds.batch(self._batch_size)
    ds = ds.prefetch(-1)
    return iter(ds)

  def run(self):
    ds = self._make_dataset()
    log_interval = 20
    step_size = self._batch_size * self._seq_length * 4
    t0 = time.time()
    for i in itertools.count(1):
      next(ds)
      if i % log_interval == 0:
        fps = int(i * step_size / (time.time() - t0))
        print('| FPS: {}, QSize: {}'.format(fps, get_queue_size(self._clients)))


def main(_):
  seqlen, bs, ns = FLAGS.seq_length, FLAGS.batch_size, FLAGS.num_servers
  sig = tree.map_structure(
      lambda x: tf.TensorSpec(shape=(seqlen,) + x.shape, dtype=x.dtype),
      spec,
  )

  # TODO
  def make_queue():
    return [reverb.Table.queue(name=TEST_TABLE, max_size=10000, signature=sig)]

  program = lp.Program('BenchReverb')

  clients = []
  with program.group('reverb'):
    for _ in range(ns):
      clients.append(program.add_node(lp.ReverbNode(make_queue)))

  with program.group('Consumer'):
    program.add_node(
        lp.CourierNode(
            Consumer, clients=clients, seq_length=seqlen, batch_size=bs))

  with program.group('Producer'):
    for i in range(FLAGS.num_workers):
      program.add_node(
          lp.CourierNode(Producer, client=clients[i % ns], seq_length=seqlen))

  lp.launch(program)


if __name__ == '__main__':
  app.run(main)
```
acassirer commented 3 years ago

Hey,

I'm a bit confused by the question, to be honest. Why would running multiple servers increase throughput? It looks like you are just spreading the producer load across the servers when running in the multi-server setup, so I would expect 1 or 4 servers to perform basically the same.

Sorry if I'm misunderstanding the question.

bingykang commented 3 years ago

Sorry for the confusion. To clarify: what I want to do is increase the throughput.

I first tried adding more tables to one server, but that did not help much. Then I turned to setting up multiple servers.
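(For context, a hypothetical sketch of what "more tables on one server" could look like, based on `make_queue` from the benchmark script above; the table names and the count of 4 are illustrative only.)

```python
# Hypothetical sketch: make_queue returning several queue tables on one
# server, with producers/consumers then spread across them. Table names and
# the count of 4 are illustrative, not taken from the original script.
def make_queue():
  return [
      reverb.Table.queue(
          name='{}_{}'.format(TEST_TABLE, i), max_size=10000, signature=sig)
      for i in range(4)
  ]
```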

Perhaps I should instead ask: what is the best practice for increasing throughput?

acassirer commented 3 years ago

Ok gotcha.

I'd say that it is unlikely that the server is the bottleneck in this setup, given the samples-per-second rate. It looks like you have a batch size of 64 and are doing ~250 steps/s, which adds up to some ~15k samples per second, below what a single server (and table) can handle. I don't have any automated benchmarks for the open-source distribution, but our internal version (which is more or less identical, though I can't rule out that our hardware or OS is magic) shows ~30k samples/s in the queue benchmark, which uses a payload of 40 kB.

It is worth noting that that setup is distributed across 36 machines.

So it doesn't sound unreasonable that you are unable to achieve more than ~15k samples/s when using a single consumer.

Side note: there is a risk that you are actually just constrained by the speed of your consumers. I suspect that the vast majority of CPU cycles is spent on np.random.randint, so it could be worth testing whether sampling the data upfront and then just cycling through it makes a difference to the overall speed.
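A minimal sketch of that pre-sampling idea, written as a hypothetical variant of `Producer.run` from the benchmark script above (the pool size of 128 is arbitrary, and the history-trimming hack is omitted for brevity):

```python
# Hypothetical variant of Producer.run: pre-sample a small pool of fake
# trajectories once and cycle through it, so np.random.randint is off the
# hot path. Assumes the imports, `spec` and TEST_TABLE from the benchmark.
def run(self):
  writer = self._client.trajectory_writer(
      num_keep_alive_refs=self._seq_length,
      get_signature_timeout_ms=300_000,
  )

  def randint(x):
    return np.random.randint(255, size=x.shape, dtype=x.dtype)

  # Pool size of 128 is an arbitrary choice for the sketch.
  pool = [tree.map_structure(randint, spec) for _ in range(128)]

  for cnt, data in enumerate(itertools.cycle(pool), start=1):
    writer.append(data)
    if cnt % self._seq_length == 0:
      traj = tree.map_structure(lambda x: x[-self._seq_length:],
                                writer.history)
      writer.create_item(TEST_TABLE, priority=1, trajectory=traj)
      writer.flush(self._seq_length)
```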

acassirer commented 3 years ago

Another piece of good news is that we recently (yesterday) submitted some changes which are looking very promising. This is especially true for the queue setup, where we are seeing a speedup of more than 2x, so hopefully that will be reflected in your benchmarks too.

You can find more details in the recent commits by @qstanczyk.

bingykang commented 3 years ago

> There is a risk that you are actually just constrained by the speed of your consumers.

I'd say this is the real problem. In my benchmarks the queue is always full, which means the consumer is not fast enough to consume the generated data. So basically, the question becomes: how can I read faster from the server?

Do you have any thoughts on this?

acassirer commented 3 years ago

As previously mentioned, there is probably limited upside here until the new optimisations are live, but one thing you can try is to run multiple consumers. I realise that it may not be a useful solution by itself, but it will give some information.
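In the benchmark script above that could look roughly like the sketch below; `num_consumers` is a hypothetical new flag, not part of the original script.

```python
# Sketch only: launching several Consumer nodes instead of one, reusing the
# launchpad program from the benchmark. FLAGS.num_consumers is hypothetical.
with program.group('Consumer'):
  for _ in range(FLAGS.num_consumers):
    program.add_node(
        lp.CourierNode(
            Consumer, clients=clients, seq_length=seqlen, batch_size=bs))
```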

Taking another look at your consumer code, I think the issue might just be how the dataset is created. It will only use one connection per server, which will most likely not be enough to reach full speed. One thing to try before anything else would be something like:


```python
  def _make_dataset(self):
    # Get signature info
    info = self._clients[0].server_info()
    shapes = tree.map_structure(lambda x: x.shape, info[TEST_TABLE].signature)
    dtypes = tree.map_structure(lambda x: x.dtype, info[TEST_TABLE].signature)

    def _inner_make_dataset(server_address):
      ds = reverb.TrajectoryDataset(
          server_address=server_address,
          table=TEST_TABLE,
          max_in_flight_samples_per_worker=100,
          shapes=shapes,
          dtypes=dtypes,
      )
      return ds

    ds = tf.data.Dataset.from_tensor_slices(self._addresses)
    ds = ds.repeat()
    ds = ds.interleave(
        map_func=_inner_make_dataset,
        cycle_length=tf.data.AUTOTUNE,
        num_parallel_calls=tf.data.AUTOTUNE,
        block_length=1,
        deterministic=False,
    )

    ds = ds.batch(self._batch_size)
    ds = ds.prefetch(2)

    return ds.as_numpy_iterator()
```

bingykang commented 3 years ago

I tried with two consumers and three consumers. Compared to one consumer (FPS 220K), the FPS increased to 199 + 191 = 390K and 150 + 151 + 159 = 460K respectively.

Then I tried building the dataset with your code for one consumer; the FPS actually drops to 190K.

acassirer commented 3 years ago

Then we've learned that the bottleneck isn't the server, so sharding the table or running multiple instances will not help.

One thing you could try is to check what happens if you use a multi-threaded consumer, that is, create multiple datasets and iterate over them from separate threads in the same process (i.e. not different processes). If that works, then you could hack something together that feeds the "real" consumer using a thread pool of datasets. If it doesn't improve the performance compared to a single consumer, then it would point towards limitations in the Python runtime (i.e. GIL contention).
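A rough sketch of that idea, assuming it runs inside `Consumer.run` from the benchmark above (the thread count of 4 is arbitrary):

```python
# Rough sketch: several dataset iterators, each drained by its own thread
# inside a single process. Like the original consumer loop, this runs
# indefinitely; a real version would push batches into a queue feeding the
# "real" consumer and track throughput.
from concurrent.futures import ThreadPoolExecutor

def drain(ds):
  for _ in itertools.count():
    next(ds)

num_threads = 4  # arbitrary for the sketch
datasets = [self._make_dataset() for _ in range(num_threads)]
with ThreadPoolExecutor(max_workers=num_threads) as pool:
  pool.map(drain, datasets)
```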

Either way, the problem evidently isn't really on the "Reverb side" but rather the generic question of how to speed up tf.data, so I'd advise you to redirect your question and dig more into that realm.

acassirer commented 3 years ago

@bingykang Since the limitation is outside of Reverb, I'm going to go ahead and close this issue. Please feel free to reopen it if you disagree with my conclusions.

Also, if you end up finding good solutions for speeding up the input pipeline, it would be fantastic if you could drop a comment here so future readers can learn from your findings.