zhen8838 / Circle-Loss

TensorFlow 2 implementation of CircleLoss. Supports class-level, sparse class-level, and pair-wise labels.
MIT License
108 stars, 40 forks

Issue: Conflict between int64 and float32 #2

Closed: kd610 closed this issue 4 years ago

kd610 commented 4 years ago

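Multiplying by y_true requires casting it to float32 first (tf.cast(y_true, tf.float32)). y_true arrives as an integer tensor (int64, according to the traceback below), and TensorFlow's Mul op does not implicitly promote mixed dtypes, so it cannot be multiplied directly with the float32 y_pred.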

Refs: https://stackoverflow.com/questions/36210887/how-to-fix-matmul-op-has-type-float64-that-does-not-match-type-float32-typeerror
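A minimal sketch of the fix, written as a standalone function rather than the repository's own `call` method. `circle_logits` is a hypothetical name, and it assumes the margin parameterization from the Circle Loss paper (O_p = 1 + m, O_n = -m, Delta_p = 1 - m, Delta_n = m); the defaults `m=0.25` and `gamma=256.0` are illustrative only. It reproduces the expression shown in the traceback, with the cast added before any multiplication.

```python
import tensorflow as tf


def circle_logits(y_true, y_pred, m=0.25, gamma=256.0):
    """Hypothetical helper mirroring the expression in the traceback,
    with the dtype fix applied.

    y_true: one-hot labels, possibly an integer tensor (e.g. int64).
    y_pred: float similarity scores with the same shape as y_true.
    """
    # The fix: cast labels to y_pred's float dtype so Mul sees matching types.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Assumed parameterization (Circle Loss paper), not read from the repo.
    O_p, O_n = 1.0 + m, -m          # optima for positive / negative scores
    Delta_p, Delta_n = 1.0 - m, m   # decision margins

    # Adaptive weights; stop_gradient keeps them out of backpropagation.
    alpha_p = tf.nn.relu(O_p - tf.stop_gradient(y_pred))
    alpha_n = tf.nn.relu(tf.stop_gradient(y_pred) - O_n)

    return (y_true * (alpha_p * (y_pred - Delta_p)) +
            (1 - y_true) * (alpha_n * (y_pred - Delta_n))) * gamma


labels = tf.constant([[0, 1, 0]], dtype=tf.int64)          # integer labels, as in the issue
scores = tf.constant([[0.1, 0.8, -0.2]], dtype=tf.float32)
logits = circle_logits(labels, scores)                     # no TypeError after the cast
print(logits.numpy())
```

Casting to `y_pred.dtype` instead of hard-coding `tf.float32` is a design choice: it behaves identically here, but would also follow `y_pred` if it were float16 or float64.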

This is the error that occurred:

/kaggle/working/circle_loss.py in call(self, y_true, y_pred)
 88     alpha_n = tf.nn.relu(tf.stop_gradient(y_pred) - self.O_n)
 89     # yapf: disable
---> 90     y_pred = (y_true * (alpha_p * (y_pred - self.Delta_p)) +
 91           (1-y_true) * (alpha_n * (y_pred - self.Delta_n))) * self.gamma
 92     # yapf: enable

 /opt/conda/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in binary_op_wrapper(x, y)
     900     with ops.name_scope(None, op_name, [x, y]) as name:
     901       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
 --> 902         return func(x, y, name=name)
     903       elif not isinstance(y, sparse_tensor.SparseTensor):
     904         try:

 /opt/conda/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in _mul_dispatch(x, y, name)
    1199   is_tensor_y = isinstance(y, ops.Tensor)
    1200   if is_tensor_y:
 -> 1201     return gen_math_ops.mul(x, y, name=name)
    1202   else:
    1203     assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.

 /opt/conda/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_math_ops.py in mul(x, y, name)
    6123   # Add nodes to the TensorFlow graph.
    6124   _, _, _op, _outputs = _op_def_library._apply_op_helper(
 -> 6125         "Mul", x=x, y=y, name=name)
    6126   _result = _outputs[:]
    6127   if _execute.must_record_gradient():

 /opt/conda/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
     502                 "%s type %s of argument '%s'." %
     503                 (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
 --> 504                  inferred_from[input_arg.type_attr]))
     505 
     506         types = [values.dtype]

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int64 of argument 'x'.
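Another option is to cast the labels once, upstream, so that every loss receives float32. The sketch below assumes the labels come from a `tf.data` pipeline as int64 one-hot vectors; `to_float_labels` is a hypothetical map function and the random features are placeholders, not part of the repository.

```python
import tensorflow as tf


def to_float_labels(features, labels):
    # Hypothetical map function: cast integer one-hot labels (e.g. int64)
    # to float32 before they ever reach the loss.
    return features, tf.cast(labels, tf.float32)


features = tf.random.normal([4, 8])
# int64 one-hot labels, matching the dtype reported in the traceback.
labels = tf.one_hot([0, 2, 1, 2], depth=3, dtype=tf.int64)

ds = (tf.data.Dataset.from_tensor_slices((features, labels))
      .map(to_float_labels)
      .batch(2))

print(ds.element_spec)
```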