gmalivenko / pytorch2keras

PyTorch to Keras model convertor
https://pytorch2keras.readthedocs.io/en/latest/
MIT License

Type error while converting PyTorch model - Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x' #126

Open jiteshm17 opened 3 years ago

jiteshm17 commented 3 years ago

I am trying to convert a PyTorch model to Keras using the pytorch2keras library. The model is a pre-trained colorization model, linked here.

To check that the model and the PyTorch weights are fine, I first tried exporting the model to ONNX. The export works, but I cannot convert the ONNX model to Keras or TensorFlow: I get the same error as with pytorch2keras.

The traceback is:

ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
   1174             r_op = getattr(y, "__r%s__" % op_name)
-> 1175             out = r_op(x)
   1176             if out is NotImplemented:

19 frames
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: <tf.Tensor 'Cast:0' shape=() dtype=float32>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
    556                 "%s type %s of argument '%s'." %
    557                 (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
--> 558                  inferred_from[input_arg.type_attr]))
    559 
    560         types = [values.dtype]

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'

I looked into this post and, from what I could understand, it has something to do with newer versions of TensorFlow.

Is there any way I could find out what layer is causing the issue?

The full code on Google Colab can be found here

Any help would be appreciated. Thanks!
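
On the question of which layer is responsible, one possible way to narrow it down is to inspect the exported ONNX graph and list the Mul nodes whose inputs do not share a dtype. The following is only a rough sketch, not part of pytorch2keras: the helper names `find_mixed_dtype_nodes` and `describe_onnx_graph` are made up for illustration, and the second one assumes the `onnx` package is installed. A valid ONNX graph should already give both Mul inputs the same type, so an empty result would suggest that the mismatch is introduced during conversion rather than by the export.

```python
from typing import Dict, List, Sequence, Tuple

# (node name, op type, input tensor names); a hypothetical, simplified node record
Node = Tuple[str, str, List[str]]


def find_mixed_dtype_nodes(
    nodes: Sequence[Node],
    dtypes: Dict[str, str],
    op_types: Sequence[str] = ("Mul",),
) -> List[Tuple[str, str, Dict[str, str]]]:
    """Return the nodes of the given op types whose known inputs disagree on dtype."""
    suspects = []
    for name, op_type, inputs in nodes:
        if op_type not in op_types:
            continue
        # Inputs without a recorded dtype (e.g. omitted optional inputs) are skipped.
        known = {i: dtypes[i] for i in inputs if i in dtypes}
        if len(set(known.values())) > 1:
            suspects.append((name, op_type, known))
    return suspects


def describe_onnx_graph(path: str) -> Tuple[List[Node], Dict[str, str]]:
    """Read an ONNX file and return (nodes, tensor name -> dtype name).

    Assumes the `onnx` package is available; it is imported lazily so the
    helper above can be used without it.
    """
    from onnx import TensorProto, load, shape_inference

    # Shape inference fills graph.value_info with types of intermediate tensors.
    graph = shape_inference.infer_shapes(load(path)).graph
    dtypes: Dict[str, str] = {}
    for info in list(graph.input) + list(graph.value_info) + list(graph.output):
        dtypes[info.name] = TensorProto.DataType.Name(info.type.tensor_type.elem_type)
    for init in graph.initializer:
        dtypes[init.name] = TensorProto.DataType.Name(init.data_type)
    # Exported nodes are often unnamed, so fall back to the first output name.
    nodes = [(n.name or n.output[0], n.op_type, list(n.input)) for n in graph.node]
    return nodes, dtypes
```

With `onnx` installed, something like `nodes, dtypes = describe_onnx_graph("model.onnx")` followed by `find_mixed_dtype_nodes(nodes, dtypes)` would print the candidates; the file name here is a placeholder. Passing `op_types=("Mul", "Resize", "Upsample")` widens the search, since shape arithmetic around resizing layers is a plausible, though unconfirmed, source of int/float mixing.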