tensorflow / datasets

TFDS is a collection of datasets ready to use with TensorFlow, Jax, ...
https://www.tensorflow.org/datasets
Apache License 2.0

FeatureConnector serialization broken for byte strings #237

Closed: razorx89 closed this issue 5 years ago

razorx89 commented 5 years ago

Short description

I am currently writing a custom FeatureConnector for RLE-encoded segmentations. To save storage space, I am going to store the numpy arrays as raw bytes and decode them with tf.io.decode_raw (since tf.train.Example only supports int64 for integer data types). However, this line breaks the serialization of the byte string: https://github.com/tensorflow/datasets/blob/c8c00f0843166775b44d3eba2e2171b44d2cb689/tensorflow_datasets/core/file_format_adapter.py#L397 One could argue that this is a bug in TensorFlow's compat module; however, bytes returned from a FeatureConnector should be passed directly into the Protobuf message instead of going through unnecessary conversions.
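The byte-encoding scheme described above can be sketched in plain NumPy. This is a minimal illustration, not the connector from the issue: np.frombuffer stands in for tf.io.decode_raw on the decoding side, and the array shape is assumed to be stored separately. It also shows why losing bytes is fatal: the decoded buffer no longer has enough elements to reshape.

```python
import numpy as np

# Encode: a uint8 mask serialized as its raw bytes (1 byte per element).
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1, 2] = 1
shape = mask.shape  # assumption: the shape is stored alongside the bytes
raw = mask.tobytes()

# Decode: np.frombuffer plays the role of tf.io.decode_raw(raw, tf.uint8).
decoded = np.frombuffer(raw, dtype=np.uint8).reshape(shape)
assert np.array_equal(decoded, mask)

# If the trailing zero bytes are lost during serialization, decoding fails.
truncated = raw.rstrip(b"\x00")
try:
    np.frombuffer(truncated, dtype=np.uint8).reshape(shape)
except ValueError as err:
    print("cannot decode truncated bytes:", err)
```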

Environment information

Reproduction instructions

import numpy as np
import tensorflow as tf
# Eager mode is not needed: tf.compat.as_bytes is plain Python.
arr = np.zeros((64, 64), dtype=np.uint8)
raw = arr.tobytes()  # avoid shadowing the built-in `bytes`
print(len(raw))  # 4096
v = np.asarray(raw).flatten()  # tensorflow_datasets/core/file_format_adapter.py#390
raw = tf.compat.as_bytes(v[0])  # tensorflow_datasets/core/file_format_adapter.py#397
print(len(raw))  # 0
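The truncation can be reproduced without TensorFlow, which suggests the bytes are already gone before tf.compat.as_bytes sees them: the flattened array keeps its full fixed width, but the element pulled out of it is empty.

```python
import numpy as np

raw = np.zeros((64, 64), dtype=np.uint8).tobytes()
v = np.asarray(raw).flatten()  # same conversion as file_format_adapter.py#390

print(v.dtype)            # |S4096: the array still reserves 4096 bytes per element
print(len(v[0]))          # 0: indexing yields an np.bytes_ with trailing NULs stripped
print(len(v.tobytes()))   # 4096: the underlying buffer is intact
```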
remydubois commented 5 years ago

Exact same issue here.

razorx89 commented 5 years ago

I patched it locally and it works fine, although this might not be the correct way to check for byte strings:

    # Pass raw byte strings straight into the proto, skipping the numpy round-trip
    if isinstance(v, bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
      continue
    v = np.array(v).flatten()  # Convert v into a 1-d array

https://github.com/tensorflow/datasets/blob/c8c00f0843166775b44d3eba2e2171b44d2cb689/tensorflow_datasets/core/file_format_adapter.py#L390
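The patched branch can be isolated as a small function for testing. normalize_feature_value is a hypothetical name, not part of TFDS; it returns the list that would be passed as value= to the proto list type. Note that isinstance(v, bytes) also matches np.bytes_ (a bytes subclass), so already-extracted NumPy scalars take the same path, while bytearray does not.

```python
import numpy as np

def normalize_feature_value(v):
    """Hypothetical helper mirroring the locally patched check."""
    if isinstance(v, bytes):
        # Byte strings bypass numpy entirely, so trailing b"\x00" bytes survive.
        return [v]
    # Everything else is flattened to a 1-d array, as in the original code.
    return np.array(v).flatten().tolist()

raw = np.array([0, 1, 0], dtype=np.uint8).tobytes()
print(len(normalize_feature_value(raw)[0]))        # 3
print(normalize_feature_value([[1, 2], [3, 4]]))   # [1, 2, 3, 4]
```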

remydubois commented 5 years ago

@razorx89 fantastic, thanks

rsepassi commented 5 years ago

This is an interesting bug.

It seems to be an issue only for byte strings that are all zeros.

For example:

import numpy as np
arr = np.ones((64, 64), dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 4096
arr = np.zeros((64, 64), dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 0
arr = np.array([0, 0, 0], dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 0
arr = np.array([0, 0, 1], dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 3
razorx89 commented 5 years ago

No, it's an issue with trailing zeros.

arr = np.array([0, 1, 0], dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 2
arr = np.zeros((64, 64), dtype=np.uint8)
arr[32, 32] = 1
len(np.array(arr.tobytes()).flatten()[0])
# 2081
len(np.array(arr.tobytes()).tobytes())
# 4096
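Both observations fit one rule: the element extracted from a fixed-width 'S' array equals the original bytes with all trailing NUL bytes removed, while leading and interior zeros are kept. A quick check over the examples in this thread (extracted is just a local helper for the sketch):

```python
import numpy as np

def extracted(raw):
    # Same conversion as in the thread: 0-d 'S' array -> flatten -> index.
    return np.array(raw).flatten()[0]

cases = [
    np.ones((64, 64), dtype=np.uint8),
    np.zeros((64, 64), dtype=np.uint8),
    np.array([0, 0, 1], dtype=np.uint8),
    np.array([0, 1, 0], dtype=np.uint8),
]
for arr in cases:
    raw = arr.tobytes()
    # Only trailing zeros are dropped; everything before them is preserved.
    assert extracted(raw) == raw.rstrip(b"\x00")
    print(len(raw), len(extracted(raw)))
```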
ParthS007 commented 5 years ago

@razorx89 I want to resolve this bug and am interested in working on it. How should I proceed? @rsepassi, could you please assign it to me?

Edit: Oh, I see you have already made a PR resolving it.

razorx89 commented 5 years ago
import numpy as np
arr = np.array([0, 1, 0], dtype=np.uint8)
len(np.array(arr.tobytes()).flatten()[0])
# 2
len(np.array(arr.tobytes()).tobytes())
# 3
len(np.array(arr.tobytes()).flatten().tobytes())
# 3
len(np.array(arr.tobytes()).flatten()[0].tobytes())
# 2
np.array(arr.tobytes()).flatten()
# array([b'\x00\x01'], dtype='|S3')
np.array(arr.tobytes()).flatten()[0]
# b'\x00\x01'
np.array(arr.tobytes()).flatten()[0].dtype
# dtype('S2')
np.array(arr.tobytes()).flatten().dtype
# dtype('S3')

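So the culprit is the indexing after flatten(): extracting an element drops the array's fixed-width dtype ('S3'), which is what preserved the trailing zeros, and NumPy strips trailing NUL bytes from the resulting scalar. From the numpy docs: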

'S', 'a' | zero-terminated bytes (not recommended)
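Given that behaviour, any code path that must route byte strings through NumPy has to avoid pulling scalars out of an 'S' array. Two options that keep the trailing zeros are sketched below; they are illustrations, not necessarily what the eventual PR did: store the value in an object array, or read the bytes from the array buffer instead of from an element.

```python
import numpy as np

raw = np.array([0, 1, 0], dtype=np.uint8).tobytes()

# Option 1: object dtype holds a reference to the original Python bytes.
obj = np.array(raw, dtype=object).flatten()
print(len(obj[0]))  # 3

# Option 2: read the fixed-width buffer instead of indexing an element.
# Only safe for a single element: shorter elements are padded to the itemsize.
s = np.array(raw).flatten()
print(len(s.tobytes()))  # 3
```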