Closed by nkyventidis 2 years ago.
Downgrading to TensorFlow 2.5.0 should help resolve the issue.
Thank you for your timely input.
Downgrading to TensorFlow 2.5.0 in Colab with pip unsurprisingly broke CUDA acceleration and caused tensorflow-probability to complain about version incompatibility.
However, downgrading to tensorflow-probability v0.13 solved the issue without having to downgrade the current version of TensorFlow in Colab (v2.8.0):
!pip install tensorflow-probability==0.13
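To confirm which versions pip actually installed after the pin (for example, when reporting results in this thread), a small stdlib-only sketch can read the package metadata. It assumes Python 3.8+ for `importlib.metadata` (the Colab runtime in the traceback below was Python 3.7, which would need the `importlib_metadata` backport), and the helper name `installed_version` is hypothetical. Note that metadata reflects what is installed on disk; a module that was already imported in the notebook keeps its old version until the runtime is restarted.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Distribution names as used by pip, e.g. in "pip install tensorflow-probability==0.13".
    for pkg in ("tensorflow", "tensorflow-probability"):
        print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```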
It would be nice to have some insight into why the error occurs with the newer tensorflow-probability version, though, if someone can help.
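The root cause is not identified in this thread. For readers unfamiliar with the failing call, it may help to see what `tfp.distributions.Categorical(logits=x).sample()` is meant to compute: one class (here, codebook) index per batch element, drawn from the softmax of that element's logits. The sketch below illustrates only those semantics in plain Python; it is not TFP's implementation, and the names `softmax` and `sample_codes` as well as the example logits are hypothetical.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_codes(logits_per_position: Sequence[Sequence[float]],
                 rng: random.Random) -> List[int]:
    """Draw one codebook index per position from its softmax distribution.

    Each inner sequence holds the logits (one per codebook entry) for a single
    position, playing the role of one batch element of Categorical(logits=x).
    """
    codes = []
    for logits in logits_per_position:
        probs = softmax(logits)
        # random.choices picks an index with probability proportional to its weight.
        codes.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return codes


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical logits for 3 positions over a 4-entry codebook.
    logits = [[2.0, 0.5, -1.0, 0.0], [0.0, 0.0, 0.0, 0.0], [-3.0, 4.0, 1.0, 0.0]]
    print(sample_codes(logits, rng))
```

In the actual example the logits come from the PixelCNN output, so each spatial position of the latent grid receives its own sampled code.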
Hello there. I've tried running the vq_vae example in Colab. However, in the "Codebook sampling" block, the line:
sampled = dist.sample()
throws the following error:
TypeError                                 Traceback (most recent call last)
in ()
      3 x = pixel_cnn(inputs, training=False)
      4 dist = tfp.distributions.Categorical(logits=x)
----> 5 sampled = dist.sample()
      6 sampler = keras.Model(inputs, sampled)

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in sample(self, sample_shape, seed, name, **kwargs)
   1232     """
   1233     with self._name_and_control_scope(name):
-> 1234       return self._call_sample_n(sample_shape, seed, **kwargs)
   1235 
   1236   def _call_sample_and_log_prob(self, sample_shape, seed, **kwargs):

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in _call_sample_n(self, sample_shape, seed, **kwargs)
   1210         sample_shape, 'sample_shape')
   1211     samples = self._sample_n(
-> 1212         n, seed=seed() if callable(seed) else seed, **kwargs)
   1213     samples = tf.nest.map_structure(
   1214         lambda x: tf.reshape(x, ps.concat([sample_shape, ps.shape(x)[1:]], 0)),

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/categorical.py in _sample_n(self, n, seed)
    246     return tf.reshape(
    247         tf.transpose(draws),
--> 248         shape=ps.concat([[n], self._batch_shape_tensor(logits=logits)], axis=0))
    249 
    250   def _cdf(self, k):

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in _batch_shape_tensor(self, **parameter_kwargs)
   1012     try:
   1013       return batch_shape_lib.inferred_batch_shape_tensor(
-> 1014           self, **parameter_kwargs)
   1015     except NotImplementedError:
   1016       raise NotImplementedError('Cannot compute batch shape of distribution '

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/batch_shape_lib.py in inferred_batch_shape_tensor(batch_object, bijector_x_event_ndims, **parameter_kwargs)
    113       bijector_x_event_ndims=bijector_x_event_ndims,
    114       require_static=False,
--> 115       **parameter_kwargs)
    116   return functools.reduce(ps.broadcast_shape, tf.nest.flatten(batch_shapes), [])
    117 

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/batch_shape_lib.py in map_fn_over_parameters_with_event_ndims(batch_object, fn, bijector_x_event_ndims, require_static, **parameter_kwargs)
    360 
    361     results[param_name] = nest.map_structure_up_to(
--> 362         param, fn, param, param_event_ndims)
    363   return results

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in map_structure_up_to(shallow_tree, func, *inputs, **kwargs)
   1427       lambda _, *values: func(*values),  # Discards the path arg.
   1428       *inputs,
-> 1429       **kwargs)
   1430 
   1431 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs)
   1524       for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
   1525   results = [
-> 1526       func(*args, **kwargs) for args in zip(flat_path_gen, flat_value_gen)
   1527   ]
   1528   return pack_sequence_as(structure=shallow_tree, flat_sequence=results,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
   1524       for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
   1525   results = [
-> 1526       func(*args, **kwargs) for args in zip(flat_path_gen, flat_value_gen)
   1527   ]
   1528   return pack_sequence_as(structure=shallow_tree, flat_sequence=results,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in <lambda>(_, *values)
   1425   return map_structure_with_tuple_paths_up_to(
   1426       shallow_tree,
-> 1427       lambda _, *values: func(*values),  # Discards the path arg.
   1428       *inputs,
   1429       **kwargs)

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/batch_shape_lib.py in get_batch_shape_tensor_part(x, event_ndims)
    137     else:
    138       base_shape = tf.shape(x)
--> 139     return _truncate_shape_tensor(base_shape, event_ndims)
    140 
    141 

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/batch_shape_lib.py in _truncate_shape_tensor(shape, rightmost_ndims_to_truncate)
    174 
    175 def _truncate_shape_tensor(shape, rightmost_ndims_to_truncate):
--> 176   shape = ps.convert_to_shape_tensor(shape, dtype_hint=np.int32)
    177   rightmost_ndims_to_truncate = ps.convert_to_shape_tensor(
    178       rightmost_ndims_to_truncate, dtype_hint=np.int32)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/layers/core/tf_op_layer.py in handle(self, op, args, kwargs)
    105         isinstance(x, keras_tensor.KerasTensor)
    106         for x in tf.nest.flatten([args, kwargs])):
--> 107       return TFOpLambda(op)(*args, **kwargs)
    108     else:
    109       return self.NOT_SUPPORTED

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)

TypeError: Dimension value must be integer or None or have an index method, got value '<attribute 'shape' of 'numpy.generic' objects>' with type '<class 'getset_descriptor'>'
I have made no modifications to the code. Does anyone know a way to fix this?
Thanks in advance for your precious time!
@sayakpaul