bfshi / scaling_on_scales

When do we not need larger vision models?
MIT License

About non-square input #6

Open githubiubiu opened 6 months ago

githubiubiu commented 6 months ago

Thank you for your fun and practical work! Have you considered non-square input, such as splitting only horizontally or only vertically? How to handle the cls_token and pos_embed in that case would also be a problem.

bfshi commented 6 months ago

Yeah, that's a good point. We can use the same image-splitting scheme for non-square images. For example, given a 224x448 image, at scale=2 (i.e., 448x896) we can split it into 2x4 crops of 224x224 each. The cls_token at each scale can be the average of the cls_tokens of its crops, and the final cls_token is the concatenation of the cls_tokens from all scales, as it is now. pos_embed can stay untouched since each crop of the non-square image is square, except at scale=1, where I guess we can interpolate pos_embed to a non-square shape.
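A minimal sketch of this splitting scheme, for illustration only (not the repo's API): it assumes a hypothetical ViT-style callable `vit` that returns a `(cls_token, patch_tokens)` pair for a 224x224 input, and for simplicity it splits every scale into square crops rather than interpolating pos_embed at scale=1 as described above.

```python
import torch
import torch.nn.functional as F

def multiscale_cls_token(image, vit, scales=(1, 2), base=224):
    """image: (B, C, H, W), with H and W multiples of `base` at scale 1."""
    B, C, H, W = image.shape
    cls_per_scale = []
    for s in scales:
        # Resize to the current scale, e.g. 224x448 -> 448x896 at scale=2.
        scaled = F.interpolate(image, size=(H * s, W * s),
                               mode='bicubic', align_corners=False)
        # Split into a grid of base x base crops (2x4 crops in the example above).
        crops = [scaled[:, :, i:i + base, j:j + base]
                 for i in range(0, H * s, base)
                 for j in range(0, W * s, base)]
        crops = torch.cat(crops, dim=0)        # (n_crops * B, C, base, base)
        cls, _ = vit(crops)                    # (n_crops * B, D)
        cls = cls.view(-1, B, cls.shape[-1])   # (n_crops, B, D)
        # cls_token at this scale = average over its crops.
        cls_per_scale.append(cls.mean(dim=0))
    # Final cls_token = concatenation across all scales.
    return torch.cat(cls_per_scale, dim=-1)    # (B, D * len(scales))
```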

We're considering adding this feature in the future. Stay tuned!

varadgunjal commented 5 months ago

Adding on to this thread since I'm interested in pursuing this further. Any idea of a rough ETA for this feature? I'd really like to test it out / help implement it. Also, what are your thoughts on doing a mapping to pre-specified resolutions (like AnyRes from LLaVA-1.6) to support non-square images of any size (not just multiples of 224)?

BTW, thanks for the great work. I tested out this approach and was glad to see positive trends in the output as reported in the paper.

bfshi commented 5 months ago

Hi @varadgunjal,

Great to hear that S2 shows some positive signals on your side!

Yes, actually I already have the feature implemented. It's just that I currently don't have enough resources to test it, e.g., on MLLMs. If you can help test it, that would be great! I can push it to a dev branch and you can pull it and try it out. Btw, what's the downstream task you're testing on?

For supporting images of any resolution, rather than just multiples of 224, I'm planning a split-and-stitch approach. For example, for a 336x336 image, I can take one 224x224 subimage at 0-224 and another at 112-336 (along each dimension). We extract features for each of them, then stitch them together into one large feature map for the 336x336 image. Since the subimages overlap, we take the average feature wherever they overlap. I can add this implementation too, and maybe you can help test it out?
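A minimal sketch of this split-and-stitch idea, again for illustration only: `extract_features` is a hypothetical function mapping a 224x224 crop to a feature map at the same spatial resolution. A real implementation would stitch at the ViT's patch-token resolution instead, with window offsets that are multiples of the patch size.

```python
import torch

def stitch_features(image, extract_features, base=224):
    """image: (C, H, W) with arbitrary H, W >= base."""
    C, H, W = image.shape
    feat_sum = None
    count = torch.zeros(1, H, W)
    # Slide base x base windows; the last window is anchored at the border,
    # e.g. for H=336 the row offsets are 0 and 112 (covering 0-224 and 112-336).
    ys = list(range(0, H - base, base)) + [H - base]
    xs = list(range(0, W - base, base)) + [W - base]
    for y in ys:
        for x in xs:
            f = extract_features(image[:, y:y + base, x:x + base])  # (D, base, base)
            if feat_sum is None:
                feat_sum = torch.zeros(f.shape[0], H, W)
            feat_sum[:, y:y + base, x:x + base] += f
            count[:, y:y + base, x:x + base] += 1
    # Average the features wherever overlapping windows cover the same position.
    return feat_sum / count
```

With base=224 and a 336x336 input, the offsets come out to exactly 0 and 112 in each dimension, matching the example above, and the 112-224 band is averaged over the two overlapping windows.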

bfshi commented 5 months ago

Hi, I pushed my implementation supporting images of any shape. Please check it out on the dev_any_shape branch. Feel free to give it a test; looking forward to your feedback!