hyy0613 / RT-1

This is the completion of google's rt-1 project code and can run directly.
Apache License 2.0
22 stars 4 forks source link

when run evaluation.py, encountered a bug "jax.errors.TracerArrayConversionError" #4

Closed lin-whale closed 10 months ago

lin-whale commented 11 months ago

hello, i encountered a bug when run evaluation.py in directroy RT-1/. It seems tf.convert_to_tensor() can't be used for a jax traced object 'observation['rgb_sequence']'. So how can I figure out the bug? thanks for you help.

class BCJaxPyPolicy(py_policy.PyPolicy):
  """Runs inference with a BC policy."""

  def __init__(self, time_step_spec, action_spec, model,network_state,
               rng, params=None, action_statistics=None):
    super(BCJaxPyPolicy, self).__init__(time_step_spec, action_spec)
    self.model = model
    self.network_state = network_state
    self.rng = rng

    self._run_action_inference_jit = jax.jit(self._run_action_inference) #, static_argnums=(0,))

  def _run_action_inference(self, observation):
    # Add a batch dim.
    observation = jax.tree_map(lambda x: jnp.expand_dims(x, 0), observation)
    # print(observation)
    # 构造模型所需的输入
    observation_input = {}
    print("***************************************************************")
    print(observation['rgb_sequence'])

    observation_input['image'] = tf.convert_to_tensor(observation['rgb_sequence'])
    observation_input['natural_language_embedding'] = tf.convert_to_tensor(observation['instruction_tokenized_use'])

    output,_ = self.model(observation_input,step_type=None, network_state=self.network_state, training=False)
    print(output['action'].numpy()[0])
    print(observation['effector_translation'][:,:,:])
    action = output['action'].numpy()[0]
    return action

  def _action(self, time_step, policy_state=(), seed=0):
    observation = time_step.observation
    action = self._run_action_inference_jit(observation)
    return policy_step.PolicyStep(action=action)
Traceback (most recent call last):
  File "/home/aistar/RT-1/evaluation.py", line 156, in <module>
    app.run(main)
  File "/home/aistar/.local/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/aistar/.local/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/aistar/RT-1/evaluation.py", line 150, in main
    evaluate_checkpoint(
  File "/home/aistar/RT-1/evaluation.py", line 121, in evaluate_checkpoint
    policy_step = policy.action(ts, ())
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tf_agents/policies/py_policy.py", line 161, in action
    return self._action(time_step, policy_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aistar/RT-1/language_table/train/policy.py", line 62, in _action
    action = self._run_action_inference_jit(observation)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
                                                 ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
                                                      ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 969, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
                                    ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 922, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aistar/RT-1/language_table/train/policy.py", line 51, in _run_action_inference
    observation_input['image'] = tf.convert_to_tensor(observation['rgb_sequence'])
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 611, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,6,256,456,3].
The error occurred while tracing the function _run_action_inference at /home/aistar/RT-1/language_table/train/policy.py:42 for jit. This concrete value was not available in Python because it depends on the value of the argument observation['rgb_sequence'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/aistar/RT-1/evaluation.py", line 156, in <module>
    app.run(main)
  File "/home/aistar/.local/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/aistar/.local/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/aistar/RT-1/evaluation.py", line 150, in main
    evaluate_checkpoint(
  File "/home/aistar/RT-1/evaluation.py", line 121, in evaluate_checkpoint
    policy_step = policy.action(ts, ())
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tf_agents/policies/py_policy.py", line 161, in action
    return self._action(time_step, policy_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aistar/RT-1/language_table/train/policy.py", line 62, in _action
    action = self._run_action_inference_jit(observation)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aistar/RT-1/language_table/train/policy.py", line 51, in _run_action_inference
    observation_input['image'] = tf.convert_to_tensor(observation['rgb_sequence'])
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
    ^^^^^^^^^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,6,256,456,3].
The error occurred while tracing the function _run_action_inference at /home/aistar/RT-1/language_table/train/policy.py:42 for jit. This concrete value was not available in Python because it depends on the value of the argument observation['rgb_sequence'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
hyy0613 commented 11 months ago

This appears to be a jax error, and for some reason I recommend that you change this line of code in policy.py to action = self._run_action_inference(observation),this will not use jax acceleration, but it I think will solve your problem.

lin-whale commented 11 months ago

This appears to be a jax error, and for some reason I recommend that you change this line of code in policy.py to action = self._run_action_inference(observation),this will not use jax acceleration, but it I think will solve your problem.

This solved the issue! Appreciate for your help!