google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0

ReverbDataset on TPU #29

Closed weichseltree closed 1 year ago

weichseltree commented 3 years ago

What would be a working setup to use reverb on a TPU? When I try iterating over the dataset I always get the error:

NotFoundError: Op type not registered 'ReverbDataset' in binary running on ***.

Is there a special TPU software version with the Reverb ops and kernels?

sabelaraga commented 3 years ago

Hey Manuel,

The ReverbDataset op is CPU only, so you have to run the ReverbDataset on CPU, and then you can use distribute_datasets_from_function: it takes a dataset_fn that executes on CPU and transfers the data to the device. See examples of usage in TF Agents.
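
For reference, a minimal sketch of that pattern; the server address, table name, dtypes/shapes, and the already-created TPUStrategy named strategy are all placeholders/assumptions:

import reverb
import tensorflow as tf

def dataset_fn(input_context):
    # This runs on the CPU host; the strategy handles transferring
    # batches to the TPU devices.
    dataset = reverb.ReplayDataset(
        server_address='localhost:8000',  # placeholder address
        table='my_table',                 # placeholder table name
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([]),),
        max_in_flight_samples_per_worker=10)
    return dataset.batch(input_context.get_per_replica_batch_size(64))

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)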

Let us know if it doesn't help.

Sabela.

weichseltree commented 3 years ago

Thank you very much for your reply. Unfortunately, I cannot get it to work. If you would like to find out what's going wrong, please take a look at this Colab notebook I wrote that reproduces my issue:

https://github.com/weichseltree/reverb_dataset_on_tpu/blob/main/reverb_dataset_on_tpu.ipynb

sabelaraga commented 3 years ago

Just to rule out other issues, can you check that you have the same TF version on your VM and on the TPU (https://cloud.google.com/tpu/docs/version-switching?hl=en)?

I think in the colab you don't need the with tf.device('/CPU:0'): block inside experience_dataset_fn.

Thanks!

weichseltree commented 3 years ago

I tried your suggestion, but it doesn't fix the error. I am using TensorFlow 2.4.1 and the TPU is set to software version 2.4.1 as well. The code runs fine when the TPU is not connected.

sabelaraga commented 3 years ago

For some reason it seems that it is trying to run the dataset op on the TPU. One thing I find weird in the colab is that the iterator is created outside of the strategy scope while the creation of the dataset is inside. AFAIK, the creation of the dataset can be inside or outside of the strategy scope (see the final section here).
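
For what it's worth, a minimal sketch of that layering (the dataset is a stand-in, and strategy is assumed to be an already-created TPUStrategy):

import tensorflow as tf

# The dataset itself can be built outside the strategy scope...
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# ...and the distributed iterator is then created from the distributed
# dataset rather than from the raw one.
iterator = iter(dist_dataset)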

weichseltree commented 3 years ago

If I play around with scope and device placement, nothing changes.

The error can occur in different places. In my notebook it fails on the .batch() operation, but if I don't call that, it simply fails on the .prefetch() or next(dataset) call instead. So the root of the problem must be in the initialization of the ReverbDataset or in the way it interacts with the interleave() op.

sabelaraga commented 3 years ago

Hey Manuel, I think we found the issue and it has to do with the software architecture of the Cloud TPUs (https://cloud.google.com/tpu/docs/system-architecture#software_architecture): the imports are not available to the binary that runs on the TPU host.

The solution is to use tf.data.service (this is an example of launching it with GKE) and register the dataset there.

I'll try to find some time to run the full workflow, but leaving here the pointers in case you want to give it a try!

Sabela.

weichseltree commented 3 years ago

Okay that looks promising. I tried the setup with a test dataset with two processes:

  1. first process not connecting to the TPU, running the DispatchServer as well as the WorkerServer
  2. second process connecting to the TPU and trying to fetch the dataset via tf.data.experimental.service.from_dataset_id

The second process now gets stuck at the next(iterator) call. Without connecting the second process to the TPU it works fine, so I guess I will have to wait for your solution.

sabelaraga commented 3 years ago

Thanks for giving this a try! One question: is the DispatchServer running in the same colab as next(iterator)?

weichseltree commented 3 years ago

Since I am not aware that running multiple notebooks on the same Colab instance is possible, I just tried it on a GCP instance. The DispatchServer is not running in the same notebook/process as the next(iterator) call. The two parts above really just mean two python processes running on the same machine.

weichseltree commented 3 years ago

Any updates on this?

sabelaraga commented 3 years ago

Two quick questions:

weichseltree commented 3 years ago

sabelaraga commented 3 years ago

And another couple of questions (to see whether this is a problem of tf.data.service with Reverb or whether something else is going on):

weichseltree commented 3 years ago

I actually didn't use the ReverbDataset in my tests. This is the code I use. First notebook:

import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))
dataset = tf.data.Dataset.range(10)
dataset_id = tf.data.experimental.service.register_dataset(
    dispatcher.target, dataset)
print(dataset_id)
print(dispatcher.target)

It then prints the dataset_id, usually 1000, and the target, something like 'grpc://localhost:44771'. I then use another notebook to connect to it like this:

import os
import tensorflow as tf

print("Tensorflow version " + tf.__version__)
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver('<my-tpu-node-name>')  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.TPUStrategy(tpu)

dataset = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service='grpc://<my-gcp-local-ip-address>:<dispatch-server-port>', # something like 'grpc://10.156.0.34:44771'
    dataset_id=<dataset-id>, # usually 1000
    element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))

iterator = iter(dataset)
print(next(iterator))

The code runs fine if I don't connect to the TPU but gets stuck if I do, so there seems to be an issue with the tf.data.service on the TPU. Both notebooks and the TPU node run TensorFlow 2.4.1.

sabelaraga commented 3 years ago

Thanks for trying this out and confirming the details.

I think the problem is still that the TPU worker cannot connect to the notebook runtime, and it's necessary to run the tf.data service separately (not in a notebook). This tutorial shows how to run it on a GKE cluster, and the TPU should be able to communicate with it.

sabelaraga commented 3 years ago

Actually, it might be easier to make sure the notebook runs in the same VPC network as the TPU (see how to configure the notebook instance here).

weichseltree commented 3 years ago

I eventually figured out how to connect the TPU to the DispatchServer using a non-reverb dataset by opening the ports involved.

However, when I switch back to the ReverbDataset, the original issue reappears:

tensorflow.python.framework.errors_impl.NotFoundError: Failed to register dataset: Op type not registered 'ReverbDataset' in binary running on <gcp-instance-name>. Make sure the Op and Kernel are registered in the binary running in this process.

This code can be used to reproduce the issue; I didn't even connect to any TPU.

import tensorflow as tf
import acme
import acme.datasets

import logging
logging.getLogger('tensorflow').setLevel(logging.DEBUG)

dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=5050))
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

dataset = acme.datasets.make_reverb_dataset(
    'localhost:9999',
    environment_spec=acme.specs.EnvironmentSpec(
                          observations=tf.TensorSpec((), dtype=tf.float32),
                          actions=tf.TensorSpec((), dtype=tf.float32),
                          rewards=tf.TensorSpec((), dtype=tf.float32),
                          discounts=tf.TensorSpec((), dtype=tf.float32)),
    table='training',
    batch_size=1,
    prefetch_size=tf.data.experimental.AUTOTUNE,
    sequence_length=10)

tf.data.experimental.service.register_dataset(
    dispatcher.target, dataset)

input()  # keep the dispatcher and worker servers alive

tensorflow: 2.4.1
dm-acme: 0.2.0
dm-reverb: 0.2.0

sabelaraga commented 3 years ago

Hey Manuel, sorry for the late reply. Do you have the notebook code used to access the dataset and reproduce the error?

Thanks!

thisiscam commented 3 years ago

Hi,

I just noticed this thread. I would like some clarification on this too: in the setup by @weichseltree, are you using TensorFlow (instead of, for example, JAX)? Unfortunately, the notebook link https://github.com/weichseltree/reverb_dataset_on_tpu/blob/main/reverb_dataset_on_tpu.ipynb is dead, so I can only guess...

Based on my understanding per this comment, it's OK to run custom TensorFlow CPU ops on the VM attached to a TPU. Perhaps it's possible to build a standalone TF custom op for the Reverb dataset and then load it via tf.load_op_library (if this is not done already, that is)? Please correct me if I'm wrong though!
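
For illustration, loading a standalone op library would look roughly like this (the .so path is hypothetical, assuming such a library were built separately):

import tensorflow as tf

# Hypothetical path to a separately built Reverb op library.
reverb_ops = tf.load_op_library('/path/to/reverb_custom_ops.so')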

Also, to my understanding, this won't be useful unless one is using TensorFlow, or is on an alpha release that allows direct SSH access to the TPU VM.

sabelaraga commented 3 years ago

Hey, IIRC, the colab was all in TF.

The reverb dataset is already a TF custom op that runs on CPU, but the binary running on the TPU worker doesn't have access to it.

But you're right, the new architecture should solve the problem.

thisiscam commented 3 years ago

but the binary running on the TPU worker doesn't have access to it.

I see. Thanks for confirming.

thisiscam commented 3 years ago

Just curious, will it work if you put the shared library file in a GCS bucket though?

sabelaraga commented 3 years ago

Not sure what you mean. In this case, if you run a tf.data.service on a separate server, the TPU worker should fetch tensors with a tf.data.service client without having to know that they come from a Reverb Dataset.

thisiscam commented 3 years ago

I think TPUs can access cloud storage bucket: https://cloud.google.com/tpu/docs/storage-buckets. I'm thinking that perhaps one can put the compiled shared library file for the custom reverb dataset op (e.g. reverb_custom_ops.so or some similar name) into a GCS bucket, and somehow direct the TPU VM to load the shared library.

I suppose this sounds like it is probably disallowed...

sabelaraga commented 3 years ago

Yeah, I'm not an expert on how the TPU VMs access Cloud storage, but I don't think the TPU worker could load that.

ebrevdo commented 3 years ago

I've asked Zak Stone to weigh in on this.

ebrevdo commented 3 years ago

Looks like the new single-VM TPU service is now available. You should be able to connect to a TPU host directly and run your Python code there. One nice advantage is that you can colocate your Reverb server and the learner in the same process, which lets Reverb bypass the RPC layer and do zero-copy transfers between the two.
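
A rough sketch of such a colocated setup (the table configuration is illustrative):

import reverb

# Server and client live in the same process, so sampling avoids a
# network hop entirely.
server = reverb.Server(tables=[
    reverb.Table(
        name='my_table',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1000,
        rate_limiter=reverb.rate_limiters.MinSize(1)),
])

client = reverb.Client(f'localhost:{server.port}')
client.insert([0, 1], priorities={'my_table': 1.0})
print(list(client.sample('my_table', num_samples=1)))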

manuel-weichselbaum commented 3 years ago

Hello everyone. It's been a while; sorry for abandoning this issue. I'm now working on the setup just like @ebrevdo described. I am on a Cloud TPU VM, which comes with TensorFlow 2.6.0 preinstalled. When I try to adapt the /usr/share/tpu/tensorflow/simple_example.py file by adding import reverb after import tensorflow as tf, I get the following error:

Traceback (most recent call last):
  File "simple_example.py", line 4, in <module>
    import reverb
  File "/usr/local/lib/python3.8/dist-packages/reverb/__init__.py", line 27, in <module>
    from reverb import item_selectors as selectors
  File "/usr/local/lib/python3.8/dist-packages/reverb/item_selectors.py", line 19, in <module>
    from reverb import pybind
  File "/usr/local/lib/python3.8/dist-packages/reverb/pybind.py", line 1, in <module>
    import tensorflow as _tf; from .libpybind import *; del _tf
ImportError: libtensorflow_framework.so.2: cannot open shared object file: No such file or directory

I cannot find the file in the tensorflow install directory; maybe this is a change from 2.5 to 2.6?

Simply copying the one from tensorflow 2.5 results in another error:

Traceback (most recent call last):
  File "simple_example.py", line 4, in <module>
    import reverb
  File "/usr/local/lib/python3.8/dist-packages/reverb/__init__.py", line 27, in <module>
    from reverb import item_selectors as selectors
  File "/usr/local/lib/python3.8/dist-packages/reverb/item_selectors.py", line 19, in <module>
    from reverb import pybind
  File "/usr/local/lib/python3.8/dist-packages/reverb/pybind.py", line 1, in <module>
    import tensorflow as _tf; from .libpybind import *; del _tf
ImportError: /usr/local/lib/python3.8/dist-packages/reverb/libpybind.so: undefined symbol: _ZN4absl12lts_202103245MutexD1Ev

Do I have to compile reverb on the Cloud TPU VM? Thank you in advance for pointing me in the correct direction :)

ebrevdo commented 3 years ago

Looks like you'll need to use a version of reverb built for your version of TF. @tfboyd just released a new minor version to match TF 2.6.0. Give that a try?

tfboyd commented 3 years ago

0.4.0 was compiled with TF 2.6.0 and should match up as expected. https://pypi.org/project/dm-reverb/#history
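
A quick runtime check that the pair matches (a minimal sketch; the expected values are for this particular release):

import tensorflow as tf
from importlib.metadata import version

print(tf.__version__)        # expect 2.6.0
print(version('dm-reverb'))  # expect 0.4.0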

manuel-weichselbaum commented 3 years ago

Okay, I think I see what's happening. A Cloud TPU VM comes with a custom tf-nightly 2.6.0 build, and that version gives the errors above. When I install the default tensorflow 2.6.0 and reverb 0.4.0, everything imports nicely. However, the TPU ops are then not available when I initialize the system as shown in the tutorial https://cloud.google.com/tpu/docs/tensorflow-quickstart-tpu-vm: InvalidArgumentError: No OpKernel was registered to support Op 'ConfigureDistributedTPU' (this also happens with the dm-reverb-nightly[tensorflow] version)

So I guess I have to build reverb to match the preinstalled version. Is that possible? How would I deviate from the build-from-source tutorial in that scenario?

MorganeAyle commented 3 years ago

Hello,

I am also getting the following error when trying to sample from the ReplayDataset after having initialized a TPU: tensorflow.python.framework.errors_impl.NotFoundError: 'ReverbDataset' is neither a type of a primitive operation nor a name of a function registered in binary running on n-9f826cf4-w-0. Make sure the operation or function is registered in the binary running in this process. [Op:DeleteIterator]

I am using TensorFlow 2.4.1 and Reverb 0.2.0. Everything works fine when sampling and training on a GPU with the same code. I tried the suggestions mentioned above, but with no luck. Was anyone able to solve this issue?

sabelaraga commented 1 year ago

Since the last report on this issue is from 2021, and things have changed significantly since then, I'm going to close it. Please reopen (or create a new issue) if you experience the same problem.