keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Stable Diffusion crashes on TPU (colab). #1460

Closed innat closed 1 year ago

innat commented 1 year ago

It crashes as soon as the model.text_to_image function is called. Is there a known issue with running SD on TPU?

tf 2.11.0
keras-cv 0.4.2

gist. sample-data.

logs: image

innat commented 1 year ago

cc. @LukeWood @miguelCalado

Same here.

innat commented 1 year ago

Error logs from tpu-vm

NotFoundError: Graph execution error:

Detected at node 'StatefulPartitionedCall' defined at (most recent call last):
    File "/usr/local/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/local/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.8/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/usr/local/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/local/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/usr/local/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/usr/local/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/usr/local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/usr/local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/usr/local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/usr/local/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/usr/local/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/usr/local/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/usr/local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_14/3185780472.py", line 1, in <module>
      a = image_gen_model.text_to_image(
    File "/usr/local/lib/python3.8/site-packages/keras_cv/models/stable_diffusion/stable_diffusion.py", line 76, in text_to_image
      encoded_text = self.encode_text(prompt)
    File "/usr/local/lib/python3.8/site-packages/keras_cv/models/stable_diffusion/stable_diffusion.py", line 118, in encode_text
      context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()])
    File "/usr/local/lib/python3.8/site-packages/keras/engine/training.py", line 2571, in predict_on_batch
      outputs = self.predict_function(iterator)
    File "/usr/local/lib/python3.8/site-packages/keras/engine/training.py", line 2137, in predict_function
      return step_function(self, iterator)
    File "/usr/local/lib/python3.8/site-packages/keras/engine/training.py", line 2123, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
Node: 'StatefulPartitionedCall'
could not find registered transfer manager for platform Host -- check target linkage
     [[{{node StatefulPartitionedCall}}]] [Op:__inference_predict_function_18563]

HoneyTyagii commented 1 year ago

The error message suggests that no transfer manager is registered for the Host platform. Check whether the target linkage is configured correctly. It may also help to review the code leading up to the failing call, in case the problem lies in how the input data is handled.
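
To look at the code preceding the error, it can help to strip the notebook and event-loop frames from the traceback and keep only the frames inside the library. The snippet below is a minimal sketch of that idea, not part of keras-cv or TensorFlow: the helper name `library_frames`, the regular expression, and the `SAMPLE_LOG` excerpt (copied from the log above) are illustrative assumptions.

```python
import re

# Matches traceback frame headers such as:
#   File "/path/to/module.py", line 76, in text_to_image
FRAME_RE = re.compile(
    r'^\s*File "(?P<path>[^"]+)", line (?P<line>\d+), in (?P<func>\S+)'
)


def library_frames(log_text, package="keras_cv"):
    """Return (path, line, function, source) tuples for frames inside `package`.

    Hypothetical helper for reading a pasted traceback; it only inspects text.
    """
    lines = log_text.splitlines()
    frames = []
    for i, text in enumerate(lines):
        match = FRAME_RE.match(text)
        if match is None:
            continue
        path = match.group("path")
        # Keep only frames whose file lives under the requested package directory.
        if f"/{package}/" not in path:
            continue
        # The statement being executed is printed on the line after the header.
        source = lines[i + 1].strip() if i + 1 < len(lines) else ""
        frames.append((path, int(match.group("line")), match.group("func"), source))
    return frames


# Excerpt of the traceback posted in this thread.
SAMPLE_LOG = '''
    File "/tmp/ipykernel_14/3185780472.py", line 1, in <module>
      a = image_gen_model.text_to_image(
    File "/usr/local/lib/python3.8/site-packages/keras_cv/models/stable_diffusion/stable_diffusion.py", line 76, in text_to_image
      encoded_text = self.encode_text(prompt)
    File "/usr/local/lib/python3.8/site-packages/keras_cv/models/stable_diffusion/stable_diffusion.py", line 118, in encode_text
      context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()])
    File "/usr/local/lib/python3.8/site-packages/keras/engine/training.py", line 2571, in predict_on_batch
      outputs = self.predict_function(iterator)
'''

for path, line, func, source in library_frames(SAMPLE_LOG):
    print(f"{func} (line {line}): {source}")
```

On the log above, the last keras_cv frame is the `predict_on_batch` call on the text encoder inside `encode_text`, so the failure happens during text encoding, before any diffusion step runs.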