keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Implementation of Focal Cutout or Focal Masking #458

Open sayakpaul opened 2 years ago

sayakpaul commented 2 years ago

Not sure if any other work has implemented and investigated this approach of Focal Masking before, but [1] combines Focal Masking and Random Masking to improve self-supervised pre-training for learning visual features.

The idea of Focal Masking (and comparison to Random Masking) is demonstrated in the figure below:

[Figure: Focal Masking vs. Random Masking applied to the same image]

Source: [1]

Masking strategies have been super useful for NLP (BERT). They have shown good promise for Vision too (Masked Autoencoders, BeiT, SimMIM, Data2Vec, etc.). I believe it'd be useful for KerasCV users because it would allow them to pre-train models using different masking strategies and investigate their impact.

Random Masking is quite similar to Cutout (available in TensorFlow Addons), which is why I used the term Focal Cutout.
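
For intuition, here is a minimal, hypothetical sketch of patch-level random masking applied directly to an image, Cutout-style (note that MSN itself drops patch tokens inside the encoder instead, as clarified later in this thread):

    import torch

    def random_patch_mask(image, patch=16, mask_ratio=0.5):
        # Black out a random subset of 16x16 patches of a float CHW tensor
        # whose height and width are multiples of `patch`.
        grid_h = image.shape[-2] // patch
        grid_w = image.shape[-1] // patch
        keep = torch.rand(grid_h, grid_w) > mask_ratio  # True = keep patch
        mask = keep.repeat_interleave(patch, 0).repeat_interleave(patch, 1)
        return image * mask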

References

[1] Masked Siamese Networks: https://arxiv.org/abs/2204.07141

/cc: @LukeWood

LukeWood commented 2 years ago

Thanks Sayak! Looks like a great tool. I'd love to be able to reproduce Masked Autoencoders using KerasCV components.

AdityaKane2001 commented 2 years ago

@LukeWood

My intuition regarding this augmentation was that we would take a few consecutive patches and black out the rest of the image. However, in their official implementation (here) they have simply used inception crop and other augmentations. I'm not sure whether that is what we want.

/cc @sayakpaul

ariG23498 commented 2 years ago

Hey folks,

I took a look at the paper and the code.

I would love to chime in on this augmentation layer if we ever want to build it in keras-cv.

sayakpaul commented 2 years ago

Thanks for your interest and for pointing that out, @ariG23498!

@AdityaKane2001 is actually implementing this as part of GSoC, so I'm not too sure about the scope for taking another contributor on this one.

ariG23498 commented 2 years ago

I understand! Can we have a label for GSoC so that we can point out which issues are already taken?

@AdityaKane2001 all the very best for your contributions!

sayakpaul commented 2 years ago

@LukeWood let's create a label for GSoC-22-related issues to better separate them?

AdityaKane2001 commented 2 years ago

@ariG23498

Thanks for the insight! I just have a small question. It looks to me like patches are being masked randomly, and not in the contiguous order illustrated in the paper. As you said, the question of RandomResizedCrop also remains, since they have used it with a very low area factor (0.05 to 0.3), which seems quite unusual.

Code snippet for reference:

    import numpy as np
    import torch

    # From the official MSN implementation: randomly drop patch tokens from
    # the ViT input `x` (shape: batch, num_tokens, dim), always retaining
    # the class token at index 0.
    if patch_drop > 0:
        patch_keep = 1. - patch_drop
        T_H = int(np.floor((x.shape[1] - 1) * patch_keep))  # tokens to keep
        perm = 1 + torch.randperm(x.shape[1] - 1)[:T_H]  # skip the class token
        idx = torch.cat([torch.zeros(1, dtype=perm.dtype, device=perm.device), perm])
        x = x[:, idx, :]
sayakpaul commented 2 years ago

I guess it's best to clarify this with the authors of the paper.

@imisra @MidoAssran could you please help us with the doubts mentioned in https://github.com/keras-team/keras-cv/issues/458#issuecomment-1155090357 and https://github.com/keras-team/keras-cv/issues/458#issuecomment-1155985462?

For context, @AdityaKane2001 is trying to implement focal masking (as done in MSN) in KerasCV to allow users to experiment with different masking strategies and study their impact on pre-training schemes.

Thanks!

MidoAssran commented 2 years ago

Hi @AdityaKane2001, focal masking is just extracting a small crop from an image-view. An image-view is created via random data augmentations of an image. For efficiency, you can do both simultaneously in the data loader with the RandomResizedCrop function in PyTorch.

Key points: Notice that the crop scale is very small (0.05, 0.3), meaning we are extracting crops that cover between 5% and 30% of the total image area, and then resizing them to 96x96 pixels for efficient batch processing (so that all the focal crops can be processed in the same forward pass). This is simply equivalent to calling RandomResizedCrop with the aforementioned scale and crop size!
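
In other words, a minimal sketch of that one-step recipe, assuming torchvision (normalization and the other view augmentations are omitted):

    import torchvision.transforms as T

    # Focal crop in one step: sample a region covering 5-30% of the image
    # area and resize it to 96x96 for efficient batched processing.
    focal_crop = T.RandomResizedCrop(96, scale=(0.05, 0.3))
    focal_view = focal_crop(image)  # `image`: PIL image or CHW tensor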

On random masking: Random masking, on the other hand, cannot be implemented in the same way, since it corresponds to dropping non-contiguous patches. Therefore, after creating an image-view, the random masking is executed in the encoder by randomly dropping input patches.

This code, here, ensures that patch dropping only happens to the random-mask views, and not to the focal views (which were already masked in the data loader).

AdityaKane2001 commented 2 years ago

@MidoAssran

Thanks for the clarification! I just have one question. The illustration in the paper suggests that the image is patched and, in the case of focal masking, a contiguous block of patches in the grid is retained while the rest of the image is dropped. However, neither the code nor the procedure you mentioned takes this into consideration. Could you please share your thoughts on this?

MidoAssran commented 2 years ago

@AdityaKane2001

If you wish to explicitly separate the image-view generation from focal masking for conceptual reasons, you can create the image-views for the focal crops using RandomResizedCrop to a size of 224x224 pixels with a scale range of approximately [0.1, 0.7] (i.e., just multiply the current range by 224/96), and then randomly keep a block of patches (a 6x6 block for the /16 networks), and that should give you the same behaviour.

However, one can simply combine those two steps from an implementation perspective to reproduce the same behaviour while improving efficiency.
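
A hypothetical sketch of that two-step variant, assuming torchvision; the function name, the 16-pixel patch size, and zeroing out the unkept region (rather than dropping tokens) are illustrative choices here, not the official implementation:

    import torch
    import torchvision.transforms as T

    def focal_mask(image, patch=16, block=6):
        # Step 1: generate the image-view with a 224x224 random resized crop.
        # `image` is a float CHW tensor.
        view = T.RandomResizedCrop(224, scale=(0.1, 0.7))(image)
        # Step 2: keep one random contiguous block x block window on the
        # patch grid and zero out everything else.
        grid = 224 // patch  # 14x14 patch grid for a /16 network
        top = int(torch.randint(0, grid - block + 1, ()))
        left = int(torch.randint(0, grid - block + 1, ()))
        mask = torch.zeros_like(view)
        mask[..., top*patch:(top+block)*patch, left*patch:(left+block)*patch] = 1.0
        return view * mask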

AdityaKane2001 commented 2 years ago

@MidoAssran

Thanks a lot for the clarification. It is clear now.

github-actions[bot] commented 5 months ago

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.