rishigami / Swin-Transformer-TF

TensorFlow implementation of the Swin Transformer model.
Apache License 2.0

Invalid argument #12

Closed · AliKayhanAtay closed 2 years ago

AliKayhanAtay commented 2 years ago

This is my basic model:


with tpu_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(
            lambda data: tf.keras.applications.imagenet_utils.preprocess_input(data, mode="torch"),
            input_shape=[224, 224, 3]),
        SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
              metrics=[RMSE])  # cfg and RMSE are defined elsewhere in my notebook

I am getting this error:

(3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]

rishigami commented 2 years ago

When you use the Swin Transformer on TPU, you need to specify the drop_remainder=True option when batching your dataset. XLA compiles the model with static shapes, so if the last batch can be smaller than BATCH_SIZE, the batch dimension becomes dynamic and the reshape inside window attention fails with the error above. Below is a code snippet from a Kaggle notebook.

def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True) # drop the partial final batch so every batch has a static shape
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset
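
Since dataset.repeat() makes the dataset infinite, model.fit() also needs an explicit steps_per_epoch. Here is a minimal sketch of how this ties into the model above; NUM_TRAINING_IMAGES and EPOCHS are placeholder names, not variables from this thread:

# NUM_TRAINING_IMAGES and EPOCHS are assumed to be defined elsewhere.
# Integer division matches drop_remainder=True, which discards the
# partial final batch so every batch XLA sees has the same shape.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE

model.fit(get_training_dataset(),
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=EPOCHS)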