AlexKuhnle / ShapeWorld

Alternatives in Records breaks batch_records #7

Closed tomsherborne closed 6 years ago

tomsherborne commented 6 years ago

Hi Alex,

I'm playing around with your tf_util interface for loading batches of data, and I find that if I generate a small dataset with:

python3 generate.py -d some_dir -a tar:bzip2 -t agreement -n oneshape -s 5,1,1 -i 100 -M -T  --config-values --correct_ratio 1.0 --captions_per_instance 5

Then when running the example data loading:

dataset = Dataset.create(dtype='agreement', name='oneshape', config='some_dir')
generated = tf_util.batch_records(dataset=dataset, mode='train', batch_size=128)

I get the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-2-58fc10a45e5c> in <module>()
      1 dataset = Dataset.create(dtype='agreement', name='oneshape', config='some_dir')
----> 2 generated = tf_util.batch_records(dataset=dataset, mode='train', batch_size=128)

~/code/acs/ShapeWorld/shapeworld/tf_util.py in batch_records(dataset, mode, batch_size)
     77                 batch[value_name] = tf.clip_by_value(t=(batch[value_name] + noise), clip_value_min=0.0, clip_value_max=1.0)
     78             elif value_type == 'int' or value_type == 'vector(int)' or value_type in dataset.vocabularies:
---> 79                 batch[value_name] = tf.cast(x=batch[value_name], dtype=tf.int32)
     80         return batch
     81 

KeyError: 'alternatives'

I think this comes from the loop on line 75, which iterates over the (key, value) pairs of the dataset.values dict, while the batch dict no longer contains alternatives because of the earlier call to records.pop('alternatives'). I added a guard that skips that key to work around this:

for value_name, value_type in dataset.values.items():
    if value_name == 'alternatives':
        # 'alternatives' was popped from the records, so skip it; `continue`
        # (rather than `break`) lets the remaining values still be processed
        continue

I have this fix on my fork and can submit it as a PR, but I'm running into a larger problem with loading data this way. The call to evaluate a batch:

with tf.Session() as sess:
    batch = sess.run(generated)

hangs for an unreasonably long time. I haven't measured it exactly, because it might never return, but it appears to need at least 10 minutes of setup time, whereas the data loading modules I've written take almost no time to evaluate a batch. I'm not sure where the issue is, but if you see the same behaviour when evaluating a batch and can point me in the right direction, I'm happy to dig into it.

[Using macOS 10.13.2, Python 3.5.4, TensorFlow 1.5.0]

AlexKuhnle commented 6 years ago

You're right, I will fix that with the next commit (just iterate over batch instead of dataset.values). Regarding the other problem: you're basically just trying to retrieve the TensorFlow records, right? I suspect the TF queue initialization is missing, in which case session.run would block waiting on a queue that never gets filled:

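import tensorflow as tf

session = tf.Session()
coordinator = tf.train.Coordinator()
# Start the threads that fill the input queues
queue_threads = tf.train.start_queue_runners(sess=session, coord=coordinator)
[...]
# Ask the queue threads to stop and wait for them before closing the session
coordinator.request_stop()
coordinator.join(threads=queue_threads)
session.close()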

I realise that this is hidden somewhere in my model code, and I should add a note to the README and the Python file if that indeed solves the problem. So please let me know. :-)
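
For reference, the fix mentioned above (iterating over batch rather than dataset.values) can be sketched with plain Python dicts standing in for the TensorFlow tensors. This is only an illustration under assumptions, not the actual tf_util code: the helper name cast_present_values, the to_int stand-in for tf.cast, and the example value names and types are all hypothetical.

```python
def to_int(value):
    """Hypothetical stand-in for tf.cast(..., dtype=tf.int32) on plain Python data."""
    if isinstance(value, list):
        return [to_int(v) for v in value]
    return int(value)


def cast_present_values(batch, value_types, int_types, cast):
    """Cast only the entries that are actually present in `batch`.

    Looping over the batch keys (not over every declared value) means that
    a key removed earlier, such as 'alternatives', is never looked up.
    """
    for value_name in list(batch):
        value_type = value_types.get(value_name)
        if value_type in int_types:
            batch[value_name] = cast(batch[value_name])
    return batch


# Hypothetical declared values; 'alternatives' is declared but absent from the batch.
value_types = {
    'caption': 'vector(int)',
    'caption_length': 'int',
    'alternatives': 'int',
    'world': 'world',
}
batch = {'caption': [1.0, 2.0], 'caption_length': 2.0, 'world': [[0.5]]}

result = cast_present_values(
    batch=batch,
    value_types=value_types,
    int_types={'int', 'vector(int)'},
    cast=to_int,
)
print(result)
```

Because the loop never touches a key that is absent from the batch, no special case for alternatives is needed, and the same loop keeps working if other values are removed from the records later.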