tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

Bug: Universal Sentence Encoder (USE) is not compatible with tf.distribute #856

Closed: jeisinge closed this issue 2 years ago

jeisinge commented 2 years ago

What happened?

The USE DAN model is an efficient embedder for short phrases, and it trains well on a single GPU. However, it fails to train on multiple GPUs with tf.distribute.

A couple of previous defects have been reported and closed, but the issue remains. The closest one I found is #515. The workaround proposed there by RobRomijnders is to use strategy.run(); however, I don't understand how to do this with Keras. Specifically, calling this method returned a PerReplica object, and I don't know how to merge it back into a regular Keras layer. See https://github.com/tensorflow/hub/issues/515#issuecomment-699928052.
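For reference, a PerReplica value returned by strategy.run() can be unwrapped with strategy.experimental_local_results(), which yields one tensor per replica; these can then be concatenated along the batch axis. The sketch below does not use the USE model and does not fix the incompatibility. The toy embed_step function and the two logical CPU devices standing in for GPUs are assumptions for illustration. When training through Keras model.fit() with the model built under strategy.scope(), Keras typically handles this per-replica bookkeeping itself, so a manual merge like this is mainly relevant to custom training loops.

```python
import tensorflow as tf

# Assumption: no real GPUs, so split the host CPU into two logical devices
# to act as two replicas. This must run before TensorFlow initializes devices.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])


def embed_step(texts):
    # Hypothetical stand-in for an encoder: a 1-d "embedding" per string,
    # namely its length as a float.
    lengths = tf.strings.length(texts)
    return tf.cast(tf.expand_dims(lengths, -1), tf.float32)


texts = tf.constant(["a", "bb", "ccc", "dddd"])
# One global batch of 4 strings; each replica receives a slice of it.
dataset = tf.data.Dataset.from_tensor_slices(texts).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
    per_replica = strategy.run(embed_step, args=(batch,))  # a PerReplica value
    # Unwrap into a tuple with one tensor per replica ...
    local = strategy.experimental_local_results(per_replica)
    # ... and merge them back into a single batch-major tensor.
    merged = tf.concat(local, axis=0)
    print(merged.numpy())
```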

The model page at https://tfhub.dev/google/universal-sentence-encoder/4 describes it as a TF2 model; however, the SavedModel records that it was exported with TF 1.15. If it were a true TF2 model, I believe it would not have an issue with tf.distribute.
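One way to see which TensorFlow version a SavedModel was exported with is to read the meta_info_def.tensorflow_version field of each MetaGraph in its saved_model.pb. This is a minimal sketch: saved_model_tf_versions is a hypothetical helper name, and it only parses the protocol buffer without loading the model.

```python
import os

from tensorflow.core.protobuf import saved_model_pb2


def saved_model_tf_versions(export_dir):
    """Return the TF version string recorded in each MetaGraph of a SavedModel.

    Hypothetical helper: it parses saved_model.pb only; nothing is loaded or run.
    """
    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())
    return [mg.meta_info_def.tensorflow_version for mg in saved_model.meta_graphs]


# Usage (downloads the model; hub.resolve returns its local cache directory):
#   import tensorflow_hub as hub
#   path = hub.resolve("https://tfhub.dev/google/universal-sentence-encoder/4")
#   print(saved_model_tf_versions(path))
```

The recorded version only says which TensorFlow release wrote the file; on its own it does not determine how the model behaves under tf.distribute.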

Colab notebook reproducing the error: https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing

Relevant code

# See https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing

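import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub

# Fake multiple GPUs: split the host CPU into two logical devices.
# This must run before TensorFlow initializes its devices.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
  cpus[0],
  [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)

def make_use_4_model():
  # Single-device model: this version trains without problems.
  inputs = keras.Input(shape=(), dtype=tf.dtypes.string, name="text_inputs")
  use = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
  outputs = use(inputs)
  return keras.Model(inputs, outputs, name="use")

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
with strategy.scope():
  # Loading USE/4 inside the strategy scope raises the InvalidArgumentError below.
  distributed_encoder_4 = hub.KerasLayer(
    handle="https://tfhub.dev/google/universal-sentence-encoder/4",
  )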

Relevant log output


/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node 'Assert/Assert' defined at (most recent call last):
    File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
      handler_func(fileobj, events)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 577, in _handle_events
      self._handle_recv()
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 606, in _handle_recv
      self._run_callback(callback, msg)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 556, in _run_callback
      callback(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
      return self.dispatch_shell(stream, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
      handler(stream, idents, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
      user_expressions, allow_stdin)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
      interactivity=interactivity, compiler=compiler, result=result)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
      if self.run_code(code, result):
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-13-246cb39154ce>", line 3, in <module>
      handle="https://tfhub.dev/google/universal-sentence-encoder/4",
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 153, in __init__
      self._func = load_module(handle, tags, self._load_options)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 449, in load_module
      return module_v2.load(handle, tags=tags, options=set_load_options)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/module_v2.py", line 106, in load
      obj = tf.compat.v1.saved_model.load_v2(module_path, tags=tags)
Node: 'Assert/Assert'
assertion failed: [Trying to access a placeholder that is not supposed to be executed. This means you are executing a graph generated from the cross-replica context in an in-replica context.]
     [[{{node Assert/Assert}}]] [Op:__inference_restored_function_body_56858]
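The assertion message refers to two tf.distribute contexts: the cross-replica context (code running directly under strategy.scope(), which is where hub.KerasLayer loads the model in the code above) and the replica context (functions executed once per replica, for example through strategy.run()). The sketch below does not reproduce the USE failure; it only uses tf.distribute.in_cross_replica_context() to show which context a piece of code runs in. The two logical CPU devices are an assumption standing in for GPUs.

```python
import tensorflow as tf

# Assumption: two logical CPU devices act as two replicas (set before TF initializes).
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

with strategy.scope():
    # Directly under strategy.scope(): cross-replica context.
    scope_is_cross_replica = tf.distribute.in_cross_replica_context()


def step():
    # Inside strategy.run(): one call per replica, each in a replica context.
    return tf.constant(tf.distribute.in_cross_replica_context())


replica_flags = strategy.experimental_local_results(strategy.run(step))
print(scope_is_cross_replica, [bool(flag.numpy()) for flag in replica_flags])
```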

tensorflow_hub Version

0.12.0 (latest stable release)

TensorFlow Version

2.8 (latest stable release)

Other libraries

No response

Python Version

3.x

OS

Linux

akhorlin commented 2 years ago

The https://tfhub.dev/google/universal-sentence-encoder/4 (and /3) models are not supported under MirroredStrategy due to the way they were published. The model's publisher currently does not plan to republish them and instead recommends using

https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-base/1
https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-base-br/1

as illustrated in https://github.com/tensorflow/hub/issues/515#issuecomment-832059992