rishigami / Swin-Transformer-TF

Tensorflow implementation of Swin Transformer model.
Apache License 2.0
198 stars 46 forks

basic_layers/layer_with_weights-3/blocks/layer_with_weights-0/attn/relative_position_index/.ATTRIBUTES/VARIABLE_VALUE; expected dtype int32 does not equal original dtype int64 #4

Closed: zerobest closed this issue 2 years ago

rishigami commented 2 years ago

Hi, could you explain the problem in more detail? FYI, here is a simple example in Colab: https://colab.research.google.com/drive/1v1yrlaQUDluwJvBsgOBfJPlmmDW7-fYk?usp=sharing

jangjiun commented 2 years ago

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
    tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
            tf.cast(data, tf.float32), mode="torch"),
        input_shape=[*IMAGE_SIZE, 3]),
    SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

TF can't load the pretrained model. This step raises the error.

rishigami commented 2 years ago

@jangjiun What are your TF and NumPy versions?

import numpy as np
import tensorflow as tf

# TensorFlow version
print(f"TensorFlow Version: {tf.__version__}")
# NumPy version
print(f"Numpy Version: {np.__version__}")

rishigami commented 2 years ago

@jangjiun Now relative_position_index is explicitly converted to int64 type. Could you try again?
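
For readers hitting the same dtype mismatch: one plausible cause is that NumPy releases before 2.0 default to a 32-bit integer on Windows, so an index array built there is int32 while the checkpoint stores int64. The sketch below is not the repository's code. It assumes the index is built with NumPy using the usual Swin relative-position scheme, and the function name `relative_position_index` is illustrative. The point is the explicit cast at the end.

```python
import numpy as np


def relative_position_index(window_size):
    """Build the pairwise relative-position index for one attention window.

    Illustrative helper (an assumption, not the repository's implementation).
    Returns an int64 array of shape (Wh*Ww, Wh*Ww).
    """
    wh, ww = window_size
    # Coordinates of every token in the window, shape (2, Wh, Ww).
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))
    coords_flat = coords.reshape(2, -1)  # (2, N) with N = Wh*Ww
    # Pairwise offsets between tokens, shape (N, N, 2).
    rel = (coords_flat[:, :, None] - coords_flat[:, None, :]).transpose(1, 2, 0).copy()
    rel[:, :, 0] += wh - 1          # shift row offsets to start at 0
    rel[:, :, 1] += ww - 1          # shift column offsets to start at 0
    rel[:, :, 0] *= 2 * ww - 1      # flatten the 2-D offset into one index
    index = rel.sum(-1)
    # Explicit cast: the platform default integer may be int32,
    # while the saved variable is int64.
    return index.astype(np.int64)


idx = relative_position_index((7, 7))
print(idx.dtype, idx.shape)
```

With the cast in place, the dtype no longer depends on the platform on which the model is constructed.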

jangjiun commented 2 years ago

TensorFlow Version: 2.4.1
Numpy Version: 1.20.3

The model loads now, but there is a new error:

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (None, 1, 1). Consider casting elements to a supported type.

jangjiun commented 2 years ago

C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\keras\engine\training.py:805 train_function
    return step_function(self, iterator)
F:\txpj_2021\swin_transformer\swintransformer\model.py:422 call
    x = self.forward_features(x)
F:\txpj_2021\swin_transformer\swintransformer\model.py:416 forward_features
    x = self.basic_layers(x)
F:\txpj_2021\swin_transformer\swintransformer\model.py:307 call
    x = self.blocks(x)
F:\txpj_2021\swin_transformer\swintransformer\model.py:242 call
    x = shortcut + self.drop_path(x)
F:\txpj_2021\swin_transformer\swintransformer\model.py:144 call
    return drop_path(x, self.drop_prob, training)
F:\txpj_2021\swin_transformer\swintransformer\model.py:131 drop_path  *
    random_tensor += tf.random.uniform(shape, dtype=inputs.dtype)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
    return target(*args, **kwargs)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\ops\random_ops.py:289 random_uniform
    shape = tensor_util.shape_tensor(shape)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\tensor_util.py:1035 shape_tensor
    return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\profiler\trace.py:163 wrapped
    return func(*args, **kwargs)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\ops.py:1540 convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\constant_op.py:339 _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\constant_op.py:264 constant
    return _constant_impl(value, dtype, shape, name, verify_shape=False,
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\constant_op.py:281 _constant_impl
    tensor_util.make_tensor_proto(
C:\ProgramData\Anaconda3\envs\tf2x\lib\site-packages\tensorflow\python\framework\tensor_util.py:551 make_tensor_proto
    raise TypeError("Failed to convert object of type %s to Tensor. "

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (None, 1, 1). Consider casting elements to a supported type.

jangjiun commented 2 years ago

OK, I have solved this problem by changing x.shape to tf.shape(x) in model.py.
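
For anyone landing here with the same `(None, 1, 1)` error: the static `x.shape` has `None` for the batch dimension while Keras traces the training step, and `tf.random.uniform` cannot turn that tuple into a tensor. `tf.shape(x)` returns the shape as a runtime tensor instead. The following is a minimal sketch of a stochastic-depth function written that way. The body is an assumption reconstructed from the traceback, not the exact contents of model.py.

```python
import tensorflow as tf


def drop_path(inputs, drop_prob, training):
    """Stochastic depth: randomly zero whole samples of the batch.

    Sketch only; the signature mirrors the call seen in the traceback,
    the body is an assumption.
    """
    if not training or drop_prob == 0.0:
        return inputs
    keep_prob = 1.0 - drop_prob
    # Dynamic batch size: a scalar int32 tensor, valid even when the
    # static batch dimension is None.
    batch = tf.shape(inputs)[0]
    rank = len(inputs.shape)  # the static rank is known at trace time
    # Mask shape (batch, 1, ..., 1) so one random draw covers each sample.
    shape = tf.concat([[batch], tf.ones([rank - 1], dtype=tf.int32)], axis=0)
    random_tensor = keep_prob + tf.random.uniform(shape, dtype=inputs.dtype)
    binary_mask = tf.floor(random_tensor)  # 1 with probability keep_prob, else 0
    return inputs / keep_prob * binary_mask


# Trace with an unknown batch size, as Keras does during model.fit.
traced = tf.function(
    lambda x: drop_path(x, 0.5, True),
    input_signature=[tf.TensorSpec([None, 4, 8], tf.float32)],
)
out = traced(tf.ones([3, 4, 8]))
print(out.shape)
```

Building the mask shape from `inputs.shape` directly would reproduce the TypeError above whenever the batch dimension is unknown at trace time.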