googlecolab / colabtools

Python libraries for Google Colaboratory
Apache License 2.0

Fit under TPU strategy fails on cardinality #1057

Open jonyvp opened 4 years ago

jonyvp commented 4 years ago

When calling the fit function on a model that was compiled under the TPUStrategy as in #1056, the cardinality of the input dataset is implicitly calculated. A TFRecordDataset is passed to fit as the dataset. The error below is raised:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-7-db3ea26a4a48> in <module>()
----> 1 mh.start_training()

8 frames
/content/modelling/NN/ModelHandler.py in start_training(self)
     73                 steps_per_epoch=self.dataloaders[0].amount_of_batches,
     74                 validation_steps=self.dataloaders[1].amount_of_batches,
---> 75                 verbose=1
     76                 )
     77 

/content/modelling/NN/Model.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    312                              validation_steps=validation_steps, validation_freq=validation_freq,
    313                              max_queue_size=max_queue_size, workers=workers,
--> 314                              use_multiprocessing=use_multiprocessing, **kwargs)
    315         return out
    316 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    258           steps_per_epoch,
    259           steps_name='steps_per_epoch',
--> 260           epochs=0)
    261 
    262       steps_per_epoch = (

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training_utils.py in infer_steps_for_dataset(model, dataset, steps, epochs, steps_name)
   1747     return None
   1748 
-> 1749   size = K.get_value(cardinality.cardinality(dataset))
   1750   if size == cardinality.INFINITE and steps is None:
   1751     raise ValueError('When passing an infinitely repeating dataset, you '

/tensorflow-2.1.0/python3.6/tensorflow_core/python/data/experimental/ops/cardinality.py in cardinality(dataset)
     49   """
     50 
---> 51   return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access

/tensorflow-2.1.0/python3.6/tensorflow_core/python/ops/gen_experimental_dataset_ops.py in dataset_cardinality(input_dataset, name)
    663         pass  # Add nodes to the TensorFlow graph.
    664     except _core._NotOkStatusException as e:
--> 665       _ops.raise_from_not_ok_status(e, name)
    666   # Add nodes to the TensorFlow graph.
    667   _, _, _op, _outputs = _op_def_library._apply_op_helper(

/tensorflow-2.1.0/python3.6/tensorflow_core/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6604   message = e.message + (" name: " + name if name is not None else "")
   6605   # pylint: disable=protected-access
-> 6606   six.raise_from(core._status_to_exception(e.code, message), None)
   6607   # pylint: enable=protected-access
   6608 

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Unable to parse tensor proto [Op:DatasetCardinality]

When calculating the cardinality outside the TPUStrategy context, as shown below, it yields a value of -2 (i.e. tf.data.experimental.UNKNOWN_CARDINALITY):

import tensorflow as tf
import tensorflow.keras.backend as K

# ds is the TFRecordDataset passed to fit
K.get_value(tf.data.experimental.cardinality(ds))  # -2 (UNKNOWN_CARDINALITY)

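For context, the infer_steps_for_dataset frame visible in the traceback above branches on exactly these sentinel values. Below is a minimal pure-Python sketch of that branching; the function name and simplification are mine, not the actual TensorFlow source:

```python
# Sentinel values returned by tf.data.experimental.cardinality
# (these constants match TensorFlow's cardinality module).
INFINITE = -1  # dataset repeats forever
UNKNOWN = -2   # cardinality cannot be statically determined

def infer_steps(size, steps=None):
    """Hypothetical simplification of how Keras resolves steps_per_epoch
    from a dataset's cardinality (cf. training_utils.infer_steps_for_dataset)."""
    if size == INFINITE and steps is None:
        # Mirrors the ValueError visible in the training_utils.py frame above.
        raise ValueError("When passing an infinitely repeating dataset, you "
                         "must specify the number of steps to run.")
    if size == UNKNOWN:
        # Unknown cardinality: fall back to whatever the caller supplied.
        return steps
    if steps is None:
        return size  # known, finite cardinality
    return steps

print(infer_steps(100))          # -> 100
print(infer_steps(UNKNOWN, 50))  # -> 50
```

Note that a TFRecordDataset has unknown cardinality by design (record counts are not stored in the file), which is why -2 is the expected result here; the InvalidArgumentError only appears when the cardinality op itself runs under the TPU strategy.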
tnovikoff commented 4 years ago

This issue appears to be a TPU issue rather than a Colab issue.