When running "train.py", I encountered the following error:
Traceback (most recent call last):
  File "C:/Users/96492/Anaconda3/Lib/tensorflow-wavenet-master/train.py", line 337, in <module>
    main()
  File "C:/Users/96492/Anaconda3/Lib/tensorflow-wavenet-master/train.py", line 255, in main
    l2_regularization_strength=args.l2_regularization_strength)
  File "C:\Users\96492\Anaconda3\Lib\tensorflow-wavenet-master\wavenet\model.py", line 620, in loss
    encoded = self._one_hot(encoded_input)
  File "C:\Users\96492\Anaconda3\Lib\tensorflow-wavenet-master\wavenet\model.py", line 515, in _one_hot
    dtype=tf.float32)
  File "C:\Users\96492\venv\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "C:\Users\96492\venv\lib\site-packages\tensorflow_core\python\ops\array_ops.py", line 3516, in one_hot
    name)
  File "C:\Users\96492\venv\lib\site-packages\tensorflow_core\python\ops\gen_array_ops.py", line 6137, in one_hot
    off_value=off_value, axis=axis, name=name)
  File "C:\Users\96492\venv\lib\site-packages\tensorflow_core\python\framework\op_def_library.py", line 632, in _apply_op_helper
    param_name=input_name)
  File "C:\Users\96492\venv\lib\site-packages\tensorflow_core\python\framework\op_def_library.py", line 61, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64
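The error means tf.one_hot only accepts integer indices (uint8, int32 or int64), while the tensor reaching it is float32. As a minimal sketch of that rule, here is a NumPy stand-in for tf.one_hot; the function name one_hot and the use of NumPy are assumptions for illustration, not code from tensorflow-wavenet or from TensorFlow itself.

```python
import numpy as np


def one_hot(indices, depth):
    """Hypothetical NumPy stand-in for tf.one_hot(indices, depth, dtype=tf.float32).

    Like tf.one_hot, it rejects index arrays that are not of an integer type.
    """
    indices = np.asarray(indices)
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError(f"indices has dtype {indices.dtype}; an integer type is required")
    flat = indices.reshape(-1)
    out = np.zeros((flat.size, depth), dtype=np.float32)
    # Put a 1.0 only where the index is in [0, depth); other rows stay all zeros.
    valid = (flat >= 0) & (flat < depth)
    out[np.nonzero(valid)[0], flat[valid]] = 1.0
    return out.reshape(indices.shape + (depth,))


print(one_hot(np.array([0, 2, 1], dtype=np.int32), depth=3))
try:
    one_hot(np.array([0.0, 2.0, 1.0], dtype=np.float32), depth=3)
except TypeError as err:
    print("TypeError:", err)
```

Out-of-range indices produce all-zero rows here, which is also how tf.one_hot treats them, so a cast that yields wrong or out-of-range values will not necessarily raise an error.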
The failing call in _one_hot (model.py, line 515) is:
    encoded = tf.one_hot(input_batch, depth=self.quantization_channels, dtype=tf.float32)
I tried to fix the problem there. I think input_batch has the wrong type (float32 instead of an integer type), so I replaced it with tf.cast(input_batch, tf.int32). The code now runs, but I wonder whether this modification is correct.
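Whether the cast is correct depends on what the float32 values are. The name encoded_input in the traceback suggests they are meant to be quantized audio levels, i.e. whole numbers in [0, quantization_channels). If so, the cast is harmless; but tf.cast truncates toward zero, so fractional or out-of-range values would silently become wrong labels. The sketch below assumes NumPy and 256 channels and uses two hypothetical helpers: mu_law_quantize, a mu-law quantizer in the style of the WaveNet paper that shows the kind of integer labels one-hot encoding expects, and to_indices, which checks a float batch before casting it.

```python
import numpy as np


def mu_law_quantize(audio, channels=256):
    """Hypothetical helper: mu-law companding followed by quantization.

    Maps float audio in [-1, 1] to integer levels in [0, channels - 1].
    """
    mu = channels - 1
    # Clip to [-1, 1] so rare overshoots cannot leave the valid range.
    audio = np.clip(np.asarray(audio, dtype=np.float64), -1.0, 1.0)
    magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude  # still in [-1, 1]
    # Shift to [0, mu] and round to the nearest level before converting.
    return np.floor((signal + 1) / 2 * mu + 0.5).astype(np.int32)


def to_indices(batch, depth):
    """Hypothetical check: cast a float batch to int32 only if it is safe."""
    batch = np.asarray(batch)
    if batch.size and not np.array_equal(batch, np.round(batch)):
        raise ValueError("fractional values would be truncated by the cast")
    if batch.size and (batch.min() < 0 or batch.max() >= depth):
        raise ValueError(f"values outside [0, {depth}) give all-zero one-hot rows")
    return batch.astype(np.int32)


levels = mu_law_quantize([-1.0, 0.0, 1.0])
print(levels, levels.dtype)
print(to_indices(levels.astype(np.float32), depth=256))
```

In the TensorFlow code itself, rounding before the cast would be tf.cast(tf.round(input_batch), tf.int32); whether rounding is appropriate, or whether the values should already be integers, depends on how encoded_input is produced before _one_hot is called.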