octo-models / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/
MIT License

shape mismatch #118

Open garspace2 opened 2 months ago

garspace2 commented 2 months ago

When I fine-tune the model on a single custom dataset and print the batch, I get the following output, followed by an error:

---> {'observation': {'image_primary': <tf.Tensor 'strided_slice_17:0' shape=(None, 64, 64, 3) dtype=uint8>,
  'image_wrist': <tf.Tensor 'Repeat/Reshape_1:0' shape=(None,) dtype=string>,
  'proprio': <tf.Tensor 'strided_slice_21:0' shape=(None, 7) dtype=float32>,
  'timestep': <tf.Tensor 'range_1:0' shape=(None,) dtype=int32>},
 'task': {'language_instruction': <tf.Tensor 'strided_slice_16:0' shape=(None,) dtype=string>},
 'action': <tf.Tensor 'concat_1:0' shape=(None, 7) dtype=float32>,
 'dataset_name': <tf.Tensor 'Repeat_1/Reshape_1:0' shape=(None,) dtype=string>}

Traceback (most recent call last):
  File "/data/RND/dengjie/code/robot/octo/scripts/finetune_mydataset.py", line 415, in <module>
    app.run(main)
  File "/home/user8/anaconda3/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user8/anaconda3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/data/RND/dengjie/code/robot/octo/scripts/finetune_mydataset.py", line 181, in main
    example_batch = next(train_data_iter)
  File "/home/user8/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4733, in __next__
    return nest.map_structure(to_numpy, next(self._iterator))
  File "/home/user8/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 810, in __next__
    return self._next_internal()
  File "/home/user8/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 773, in _next_internal
    ret = gen_dataset_ops.iterator_get_next(
  File "/home/user8/anaconda3/lib/python3.9/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 3029, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "/home/user8/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_23_device_/job:localhost/replica:0/task:0/device:CPU:0}} condition [10], then [10,64,64,3], and else [10,64,64,3] must be broadcastable
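The last line matches the error raised by TensorFlow's `tf.where` (the SelectV2 op), which requires the condition and both branches to be broadcastable against each other. As an illustration only, the pure-Python sketch below reimplements the usual NumPy/TensorFlow broadcasting rule: shapes are aligned at their trailing axis, and each aligned pair must be equal or contain a 1. `broadcast_shapes` here is a hypothetical helper, not a TensorFlow API. Fed the shapes from the message, it shows that a per-timestep condition of shape [10] clashes with image batches of shape [10, 64, 64, 3], because the last axes compare 10 against 3.

```python
from itertools import zip_longest
from typing import Optional, Sequence, Tuple


def broadcast_shapes(*shapes: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: return the broadcast shape, or None if incompatible.

    Follows the NumPy/TensorFlow rule: align shapes at their trailing
    dimension; each aligned group of sizes must be equal or be 1.
    """
    result = []
    # Walk dimensions from right to left; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            return None  # two different non-1 sizes in the same position
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Shapes taken from the error message in this issue.
condition = (10,)
images = (10, 64, 64, 3)

print(broadcast_shapes(condition, images, images))      # None: 10 vs 3 on the last axis
print(broadcast_shapes((10, 1, 1, 1), images, images))  # (10, 64, 64, 3)
print(broadcast_shapes(condition, (10,), (10,)))        # (10,): one string per frame
```

With NumPy installed, `numpy.broadcast_shapes((10,), (10, 64, 64, 3))` raises a `ValueError` for the same reason.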

When I use bridge_dataset instead and print the batch, the output is:

---> {'observation': {'image_primary': <tf.Tensor 'strided_slice_17:0' shape=(None,) dtype=string>,
  'image_wrist': <tf.Tensor 'Repeat/Reshape_1:0' shape=(None,) dtype=string>,
  'proprio': <tf.Tensor 'strided_slice_21:0' shape=(None, 7) dtype=float32>,
  'timestep': <tf.Tensor 'range_1:0' shape=(None,) dtype=int32>},
 'task': {'language_instruction': <tf.Tensor 'strided_slice_16:0' shape=(None,) dtype=string>},
 'action': <tf.Tensor 'concat_1:0' shape=(None, 7) dtype=float32>,
 'dataset_name': <tf.Tensor 'Repeat_1/Reshape_1:0' shape=(None,) dtype=string>}
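Comparing the two logs, the only structural difference is `image_primary`: in bridge_dataset it is a batch of encoded strings with shape (None,), so a per-timestep condition of shape [10] lines up with it, while the custom dataset yields raw uint8 images of shape (None, 64, 64, 3). That suggests two possible workarounds, neither verified against Octo here: store the custom images as encoded strings, as bridge_dataset appears to, or give the condition trailing size-1 axes before the selection. The helper below is a hypothetical plain-Python sketch of the shape arithmetic for the second option; in TensorFlow the condition would then be reshaped (for example with `tf.reshape`) to the returned shape.

```python
from typing import Sequence, Tuple


def expand_condition_shape(cond_shape: Sequence[int], value_rank: int) -> Tuple[int, ...]:
    """Hypothetical helper: pad a per-timestep condition shape with trailing 1s.

    A condition of shape [T] becomes [T, 1, 1, 1] for rank-4 image batches,
    so it broadcasts against [T, H, W, C]; rank-1 string images are unchanged.
    """
    missing = value_rank - len(cond_shape)
    if missing < 0:
        raise ValueError("condition has more dimensions than the value tensor")
    return tuple(cond_shape) + (1,) * missing


print(expand_condition_shape((10,), 4))  # (10, 1, 1, 1) for uint8 image batches
print(expand_condition_shape((10,), 1))  # (10,) for string-encoded images
```

Storing images as strings keeps every observation key at rank 1 before decoding, which is presumably why the bridge_dataset batch passes through the same code path.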

How can I solve this? Thanks.