awsaf49 / gcvit-tf

TensorFlow 2.0 Implementation of GCViT: Global Context Vision Transformer
MIT License

Network does not support arbitrary input size #18

Closed andreped closed 1 year ago

andreped commented 1 year ago

Is your feature request related to a problem? Please describe. Most, if not all, Keras network architectures support changing the input size. This is also true for the ViT-B/16 implementation found in the vit-keras project. When trying to do the same with this implementation, it fails.

Describe the solution you'd like I don't expect *any* input shape to work, but it would be great if it were possible to change it, as is the case with other backbone implementations out there.

Additional context A part of the error message can be seen here:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling layer 'attn' (type WindowAttention).

{{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 3211264 values, but the requested shape has 802816 [Op:Reshape]

Call arguments received by layer 'attn' (type WindowAttention):
  • inputs=['tf.Tensor(shape=(256, 49, 64), dtype=float16)', 'tf.Tensor(shape=(1, 14, 14, 64), dtype=float16)']
  • kwargs={'training': 'None'}

I tried to change from 224x224 to 1024x1024 input size.
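The mismatch in the traceback can be decoded with a little arithmetic (this is a reading of the logged shapes only, not of the library internals):

```python
# The reshape target (256, 49, 64) hard-codes the 224x224 geometry:
# 256 windows, 7x7 = 49 tokens per window, 64 channels.
expected = 256 * 49 * 64
print(expected)              # 802816 -- the "requested shape" in the error

received = 3211264           # values actually reaching the reshape
print(received // expected)  # 4 -- the larger input produced 4x the values
```

In other words, the attention layer's reshape assumes the window grid of the default 224x224 configuration, so any input that changes the feature-map size feeds it a different number of values.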

awsaf49 commented 1 year ago

That is weird, because I tried different input shapes before and even trained models with them. I think this is due to changes in recent TensorFlow versions.

awsaf49 commented 1 year ago

@andreped Hi, just set `resize_query=True` during model build, and it will support any input size =)
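For anyone landing here, the fix might look like the sketch below. The `GCViTTiny` constructor name and the build-by-calling pattern are assumptions about this package's API and should be checked against the repo's README:

```python
# Hedged sketch, not verified against the current gcvit API: assumes the
# package exposes a GCViTTiny constructor and that resize_query is
# accepted at build time, as the maintainer describes.
import tensorflow as tf
from gcvit import GCViTTiny

model = GCViTTiny(resize_query=True)

# Build the model at the desired (non-default) resolution.
_ = model(tf.zeros((1, 1024, 1024, 3)))
```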

andreped commented 1 year ago

> @andreped Hi, just set `resize_query=True` during model build, and it will support any input size =)

@awsaf49 Setting `resize_query=True` indeed resolves the issue! At least the training launched. Cheers :]

Sorry for the late reply. I've been on vacation and have not had time to return to this issue until now.