DagsHub / streaming-client

MIT License
2 stars 0 forks source link

Fixed Tensor Stacking bug on TensorFlow Dataloader #35

Closed jinensetpal closed 1 year ago

jinensetpal commented 1 year ago

Error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-21-0a8375c685c1> in <cell line: 5>()
      3 dataloader = query_result.as_ml_dataloader(flavor='tensorflow', tensorizers=['image', lambda x: tf.convert_to_tensor(x, tf.uint8)], metadata_columns=['int_field'])
      4 
----> 5 for X, y in dataloader:
      6   print(X, y)
      7   # some training here

4 frames
/usr/local/lib/python3.10/dist-packages/keras/utils/data_utils.py in __iter__(self)
    564     def __iter__(self):
    565         """Create a generator that iterate over the Sequence."""
--> 566         for item in (self[i] for i in range(len(self))):
    567             yield item
    568 

/usr/local/lib/python3.10/dist-packages/keras/utils/data_utils.py in <genexpr>(.0)
    564     def __iter__(self):
    565         """Create a generator that iterate over the Sequence."""
--> 566         for item in (self[i] for i in range(len(self))):
    567             yield item
    568 

/usr/local/lib/python3.10/dist-packages/dagshub/data_engine/client/loaders/tf.py in __getitem__(self, index)
     51         for index in indices:
     52             X.append(self.dataset.__getitem__(index))
---> 53         return tf.stack(X)
     54 
     55     def on_epoch_end(self) -> None:

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7260 def raise_from_not_ok_status(e, name):
   7261   e.message += (" name: " + name if name is not None else "")
-> 7262   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7263 
   7264 

InvalidArgumentError: {{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [290,611,3] != values[1].shape = [] [Op:Pack] name: 0

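The dataloader attempted to combine all the tensors in a batch into a single tensor with `tf.stack`, which requires every input to have the same shape. Here that fails because an image tensor of shape `[290, 611, 3]` is being stacked with a scalar metadata value of shape `[]`. Returning a plain list of tensors instead, as the torch dataloader does, fixes this.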