keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Add MaxViT model #912

Closed: innat closed this 1 year ago

innat commented 2 years ago

Short Description

MaxViT (Multi-Axis Vision Transformer) is a family of hybrid (CNN + ViT) image classification models that outperform both SoTA ConvNets and Transformers across the board, in terms of both parameter and FLOPs efficiency.

Papers

https://arxiv.org/abs/2204.01697

Existing Implementations

Official Implementation: Google, TensorFlow 2 (Keras). https://github.com/google-research/maxvit

cc. @Yinxiaoli @vztu

bhack commented 2 years ago

Quite related to https://github.com/keras-team/keras-cv/issues/911

ayulockin commented 2 years ago

Hey all I can work on this. :)

DavidLandup0 commented 2 years ago

I'd gladly also port this from the official repo to here. :) If someone could assign me to it, I'd get to it as soon as the Dice and Jaccard coefficients are done.

tanzhenyu commented 2 years ago

> Hey all I can work on this. :)

@ayulockin @innat Ideally we would like to have SwinTransformer first: https://github.com/keras-team/keras-cv/issues/671

innat commented 2 years ago

@tanzhenyu I think vanilla ViT should be first, it's like VGG for transformers 😄

In the meantime, I think it's also OK to start working on the basic components, like window partition, grid attention, trail-dense, etc. cc @ayulockin @DavidLandup0
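For readers new to MaxViT, the "window partition" step behind block attention boils down to a reshape/transpose pair. The following is a minimal TensorFlow sketch, assuming a [B, H, W, C] input with static H and W divisible by the window size P; the names `window_partition` and `window_unpartition` are hypothetical and are not the KerasCV or official MaxViT API.

```python
import tensorflow as tf


def window_partition(x, window_size):
    """Split [B, H, W, C] into non-overlapping P x P windows.

    Returns [B * (H // P) * (W // P), P * P, C]; each row holds one window's
    tokens in row-major order. Assumes static H, W divisible by P.
    """
    _, height, width, channels = x.shape
    p = window_size
    # Split each spatial axis into (number of windows, P): [B, H/P, P, W/P, P, C].
    x = tf.reshape(x, [-1, height // p, p, width // p, p, channels])
    # Group the window indices together: [B, H/P, W/P, P, P, C].
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    # Flatten each window into a sequence of P * P tokens.
    return tf.reshape(x, [-1, p * p, channels])


def window_unpartition(windows, window_size, height, width):
    """Inverse of window_partition: [B * nW, P * P, C] -> [B, H, W, C]."""
    p = window_size
    channels = windows.shape[-1]
    x = tf.reshape(windows, [-1, height // p, width // p, p, p, channels])
    # Undo the transpose: [B, H/P, P, W/P, P, C].
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    return tf.reshape(x, [-1, height, width, channels])
```

Self-attention (for example `tf.keras.layers.MultiHeadAttention`) applied to the partitioned tensor then stays local to each P x P window, and `window_unpartition` restores the feature map afterwards.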

tanzhenyu commented 2 years ago

> @tanzhenyu I think vanilla ViT should be first, it's like VGG for transformers 😄
>
> In the meantime, I think it's also OK to start working on the basic components, like window partition, grid attention, trail-dense, etc. cc @ayulockin @DavidLandup0

https://github.com/keras-team/keras-cv/issues/668

DavidLandup0 commented 2 years ago

Creating a pull request later today with layers for patching, mlp heads, linear projections, etc. We can use those to build a ViT and then extend it to Swin and other transformers for vision. A rough draft for ViT will be coming in with the basic layers. Would you prefer a PR for components, and then a PR for ViT on a different branch instead? @tanzhenyu @innat
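To make the "patching" and "linear projection" pieces concrete, here is a minimal sketch built only from core TensorFlow/Keras ops. The class names `Patching` and `PatchEncoder` mirror the draft names used in this thread, but these implementations are hypothetical: a common pattern (non-overlapping patches via `tf.image.extract_patches`, then a Dense projection plus a learned position embedding), not the code from the PR.

```python
import tensorflow as tf


class Patching(tf.keras.layers.Layer):
    """Cut [B, H, W, C] images into flattened, non-overlapping patches."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        size = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images, sizes=size, strides=size,
            rates=[1, 1, 1, 1], padding="VALID",
        )
        # [B, H/P, W/P, P*P*C] -> [B, num_patches, P*P*C]
        batch = tf.shape(patches)[0]
        return tf.reshape(patches, [batch, -1, patches.shape[-1]])


class PatchEncoder(tf.keras.layers.Layer):
    """Linearly project patches and add a learned position embedding."""

    def __init__(self, num_patches, project_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = tf.keras.layers.Dense(project_dim)
        self.position_embedding = tf.keras.layers.Embedding(num_patches, project_dim)

    def call(self, patches):
        positions = tf.range(self.num_patches)
        # [B, N, D] + [N, D] broadcasts over the batch axis.
        return self.projection(patches) + self.position_embedding(positions)
```

For a 32x32 RGB image and `patch_size=8`, `Patching` yields 16 patches of 8 * 8 * 3 = 192 values each, which `PatchEncoder(16, project_dim)` maps to 16 tokens of width `project_dim`.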

bhack commented 2 years ago

Not that we need to do the same, but I would also take a look at how the modularization is organized in the quite popular Hugging Face Transformers API: https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py

tanzhenyu commented 2 years ago

> Creating a pull request later today with layers for patching, mlp heads, linear projections, etc. We can use those to build a ViT and then extend it to Swin and other transformers for vision. A rough draft for ViT will be coming in with the basic layers. Would you prefer a PR for components, and then a PR for ViT on a different branch instead? @tanzhenyu @innat

Given we don't anticipate the need to expose components such as linear projections as public APIs, either a single PR or multiple PRs sounds good to me. In this case, it really depends on whether you want them to be reused in Swin and others. Modularity is the key. If that's what you want, would you mind coming up with a basic design to show how those components can fit into other models?

DavidLandup0 commented 2 years ago

Sure! I'm packaging them into one PR as a draft overview, just to check whether the general structure is okay. It'd be unwise to work more on it if major changes need to be done. I'm testing out a rough idea and will push it in later for a cursory look/review :)

The idea was to build blocks that we can reuse for most transformer-based models. Currently, building a ViT with it looks like this:

# Draft sketch: keras_cv.layers.Patching, keras_cv.transformers.PatchEncoder and
# keras_cv.transformers.TransformerEncoder are proposals from this PR, and
# mlp_ffn is a helper assumed to be defined elsewhere.
inputs = utils.parse_model_inputs(input_shape, input_tensor)
x = inputs

if include_rescaling:
    x = layers.Rescaling(1 / 255.0)(x)

patches = keras_cv.layers.Patching(patch_size)(x)
encoded_patches = keras_cv.transformers.PatchEncoder(num_patches, project_dim)(patches)

# Chain the encoder blocks: each block consumes the previous block's output.
x = encoded_patches
for _ in range(transformer_layer_num):
    x = keras_cv.transformers.TransformerEncoder()(x)  # constructor args omitted in this draft

representation = layers.LayerNormalization(epsilon=1e-6)(x)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)

features = mlp_ffn(representation, hidden_units=head_units, dropout_rate=0.5)
logits = layers.Dense(num_classes)(features)
model = keras.Model(inputs=inputs, outputs=logits)

ayulockin commented 2 years ago

> In the meantime, I think it's also OK to start working on the basic components, like window partition, grid attention, trail-dense, etc. cc @ayulockin @DavidLandup0

I think adding basic components as you mentioned should be the way to go. KerasCV's aim is to provide components for industrial adoption of research. Instead of focusing on models (ViT, Swin, etc.), I think we should scope the transformers for vision so that we can build fundamental building blocks.

bhack commented 2 years ago

> I think adding basic components as you mentioned should be the way to go. KerasCV's aim is to provide components for industrial adoption of research. Instead of focusing on models (ViT, Swin, etc.), I think we should scope the transformers for vision so that we can build fundamental building blocks.

This is why I suggested exploring the Hugging Face Transformers modules.

It is probably not the best modularization we could achieve, but they have already accumulated a fairly large list of transformer architectures in the library.

I don't know whether it is production-level, but it is at least partially validated by the number of models.

tanzhenyu commented 2 years ago

This seems concise. My only question here is whether the TransformerEncoder is implemented differently from "normal" transformer encoders or in the same way. If it's the same, we might want to consider bringing this to core Keras instead of KerasCV.

And yes I agree with @bhack and @ayulockin that we should take a look at HF's implementation and make sure we're providing enough modularization.
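For comparison, a "normal" ViT encoder block uses LayerNorm before each sub-block (pre-norm), multi-head self-attention, and a GELU MLP, each wrapped in a residual connection. The sketch below is a hypothetical `PreNormEncoderBlock` built only from core Keras layers, assuming the input's last dimension already equals `project_dim`; it is meant to show what "the same way" would look like, not KerasCV's actual implementation.

```python
import tensorflow as tf


class PreNormEncoderBlock(tf.keras.layers.Layer):
    """Standard pre-norm Transformer encoder block (illustrative, not KerasCV's)."""

    def __init__(self, project_dim, num_heads, mlp_dim, mlp_dropout=0.1,
                 attention_dropout=0.1, layer_norm_epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=project_dim // num_heads,
            dropout=attention_dropout,
        )
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_dim, activation="gelu"),
            tf.keras.layers.Dropout(mlp_dropout),
            tf.keras.layers.Dense(project_dim),
            tf.keras.layers.Dropout(mlp_dropout),
        ])

    def call(self, x, training=None):
        # Attention sub-block: LayerNorm -> self-attention -> residual add.
        h = self.norm1(x)
        x = x + self.attn(h, h, training=training)
        # MLP sub-block: LayerNorm -> Dense/GELU/Dense -> residual add.
        return x + self.mlp(self.norm2(x), training=training)
```

If a vision-specific encoder differs from this only in defaults, that supports the argument for a shared (possibly core Keras) implementation.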

bhack commented 2 years ago

> We might want to consider bringing this to core Keras instead of KerasCV if it's the same way

Yes, this is another very important point, already discussed, for minimizing (future?) duplication with KerasNLP.

DavidLandup0 commented 1 year ago

Now that ViTs are finished, I'll be working on this one ;) If anyone wants to collaborate, let me know. (@ayulockin wanted to work on this a while back)

ayulockin commented 1 year ago

Hey, @DavidLandup0, I would love to collaborate on this with you. :) I was waiting for the ViT to be added so I could build on top of it from a design point of view. Since you have worked on it, collaborating with you would be a great learning experience. :)

innat commented 1 year ago

@ayulockin just to let you know, MAXIM is welcome too. Most of the official code (JAX) and weights have been ported to Keras, here.

DavidLandup0 commented 1 year ago

If nobody else signs up for it by the time MaxViT is done, I'd gladly hop onto MAXIM too :)

DavidLandup0 commented 1 year ago

Since MaxViT uses MBConvs, which we have in EfficientNets and which originated in MobileNets, we'll have three architectures reusing the same blocks. Additionally, having them as a layer would let users build their own networks with them for edge/mobile applications.

I think we should have MBConv as a standalone layer.

Can I separate it into a layer and refactor EfficientNets in preparation for MaxViT? @tanzhenyu @LukeWood @bhack
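For reference, an EfficientNet-style MBConv block can be written as a plain function: 1x1 expansion, depthwise convolution, optional squeeze-and-excitation (SE), and a linear 1x1 projection, with a residual connection when shapes allow. This is a hypothetical sketch (`mbconv` and `squeeze_excite` are illustrative names, not the KerasCV layer); it uses SiLU/Swish as EfficientNet does, and MaxViT's exact choices (activation, normalization placement, downsampling shortcut) may differ.

```python
import tensorflow as tf

layers = tf.keras.layers


def squeeze_excite(x, se_filters):
    """Squeeze-and-Excitation: rescale channels by a gated global summary."""
    channels = x.shape[-1]
    s = layers.GlobalAveragePooling2D(keepdims=True)(x)  # [B, 1, 1, C]
    s = layers.Conv2D(se_filters, 1, activation=tf.nn.silu)(s)
    s = layers.Conv2D(channels, 1, activation="sigmoid")(s)
    return layers.Multiply()([x, s])  # broadcast over H and W


def mbconv(x, out_channels, expand_ratio=4, kernel_size=3, strides=1, se_ratio=0.25):
    """Inverted residual block: expand -> depthwise -> SE -> project (+ skip)."""
    in_channels = x.shape[-1]
    shortcut = x
    # 1x1 expansion to a wider hidden representation.
    if expand_ratio != 1:
        x = layers.Conv2D(in_channels * expand_ratio, 1, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation(tf.nn.silu)(x)
    # Depthwise spatial convolution (optionally strided for downsampling).
    x = layers.DepthwiseConv2D(kernel_size, strides=strides, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(tf.nn.silu)(x)
    # Squeeze-and-Excitation, sized relative to the block's input channels.
    if se_ratio:
        x = squeeze_excite(x, max(1, int(in_channels * se_ratio)))
    # Linear 1x1 projection (no activation).
    x = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    # Residual connection only when input and output shapes match.
    if strides == 1 and in_channels == out_channels:
        x = layers.Add()([shortcut, x])
    return x
```

Packaging the same logic as a `Layer` subclass would let EfficientNet, MobileNet and MaxViT share one implementation, which is the refactor proposed here.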

IMvision12 commented 1 year ago

I can work on MAXIM !!

tanzhenyu commented 1 year ago

> Since MaxViT uses MBConvs, which we have in EfficientNets and which originated in MobileNets, we'll have three architectures reusing the same blocks. Additionally, having them as a layer would let users build their own networks with them for edge/mobile applications.
>
> I think we should have MBConv as a standalone layer.
>
> Can I separate it into a layer and refactor EfficientNets in preparation for MaxViT? @tanzhenyu @LukeWood @bhack

Yep, it'd be great to reuse both MBConv and SE

DavidLandup0 commented 1 year ago

Done in a new PR :) #1146

tanzhenyu commented 1 year ago

> I can work on MAXIM !!

Go ahead!

ayulockin commented 1 year ago

Here is a quick update on the work done so far:

Work done in collaboration with @DavidLandup0 :)

We have almost all the components - WindowPartition, UnWindowPartition, GridPartition, UnGridPartition and RelativeMultiHeadAttention done.

We have stacked them together to build a bare-bones MaxViTBlock, and its input and output signatures match the official implementation. We will package it in a class and create MaxViT variants. Will send over a PR once done. :)

@DavidLandup0, do you have anything more to add?

cc: @innat @bhack @tanzhenyu
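Where window partitioning groups contiguous pixels, grid partitioning (used by MaxViT's grid attention) groups G x G pixels spread uniformly over the whole feature map, which gives sparse, dilated global mixing. The following is a hypothetical functional sketch in the spirit of the GridPartition/UnGridPartition components listed above, not the layers from the PR, assuming a [B, H, W, C] input with static H and W divisible by the grid size G.

```python
import tensorflow as tf


def grid_partition(x, grid_size):
    """Split [B, H, W, C] into groups of G x G tokens spread over a uniform grid.

    Token (a, b) of a group comes from pixel (a * H/G + i, b * W/G + j) for a
    fixed offset (i, j), so each group spans the whole feature map.
    Returns [B * (H // G) * (W // G), G * G, C]. Assumes H, W divisible by G.
    """
    _, height, width, channels = x.shape
    g = grid_size
    # [B, G, H/G, G, W/G, C]: the G axes index the coarse grid cell.
    x = tf.reshape(x, [-1, g, height // g, g, width // g, channels])
    # Move the in-cell offsets forward: [B, H/G, W/G, G, G, C].
    x = tf.transpose(x, [0, 2, 4, 1, 3, 5])
    return tf.reshape(x, [-1, g * g, channels])


def grid_unpartition(groups, grid_size, height, width):
    """Inverse of grid_partition: [B * nG, G * G, C] -> [B, H, W, C]."""
    g = grid_size
    channels = groups.shape[-1]
    x = tf.reshape(groups, [-1, height // g, width // g, g, g, channels])
    # Undo the transpose: back to [B, G, H/G, G, W/G, C].
    x = tf.transpose(x, [0, 3, 1, 4, 2, 5])
    return tf.reshape(x, [-1, height, width, channels])
```

On an 8x8 map with G = 2, the first group collects pixels (0, 0), (0, 4), (4, 0) and (4, 4), so attention over each group mixes information from all four quadrants.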

DavidLandup0 commented 1 year ago

Thanks for tagging and awesome work on RelativeMultiHeadAttention! Question for the Keras team - do we want to make RelativeMultiHeadAttention part of core Keras? MHA already is, and the relative variant is general enough for it, IMO.
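As we understand the MaxViT paper, its relative attention adds a learned relative position bias to the attention logits before the softmax. One common way to parameterize such a bias for a P x P window (also used by Swin Transformer) is a table of (2P - 1)^2 entries per head, indexed by each query/key pair's 2D offset. The sketch below assumes that formulation; `relative_position_index` and `RelativePositionBias` are hypothetical names, and wiring the bias into a RelativeMultiHeadAttention layer (for example by subclassing `MultiHeadAttention`) is left out.

```python
import tensorflow as tf


def relative_position_index(window_size):
    """Map each (query, key) pair in a P x P window to a row of a bias table.

    Returns int32 [P*P, P*P] with values in [0, (2P - 1) ** 2).
    """
    p = window_size
    # Row/column coordinates of every token in the window: [2, P*P].
    grid = tf.meshgrid(tf.range(p), tf.range(p), indexing="ij")
    coords = tf.reshape(tf.stack(grid), [2, -1])
    # Pairwise offsets (query - key) along each axis, in [-(P-1), P-1].
    rel = coords[:, :, None] - coords[:, None, :]
    rel = rel + (p - 1)  # shift into [0, 2P - 2]
    # Flatten the 2D offset into a single table index.
    return rel[0] * (2 * p - 1) + rel[1]


class RelativePositionBias(tf.keras.layers.Layer):
    """Add a learned per-head relative position bias to attention logits."""

    def __init__(self, window_size, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.num_heads = num_heads
        self.index = relative_position_index(window_size)  # [P*P, P*P]

    def build(self, input_shape):
        num_offsets = (2 * self.window_size - 1) ** 2
        self.table = self.add_weight(
            name="relative_bias_table",
            shape=(num_offsets, self.num_heads),
            initializer="zeros",
            trainable=True,
        )

    def call(self, logits):
        # logits: [B, num_heads, P*P, P*P], before the softmax.
        bias = tf.gather(self.table, self.index)  # [P*P, P*P, num_heads]
        bias = tf.transpose(bias, [2, 0, 1])  # [num_heads, P*P, P*P]
        return logits + bias
```

Because the bias depends only on relative offsets, the same table is shared by every window, which keeps the parameter count independent of the feature-map size.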

Since we should package the components for review first, it's enough to have a rough model for the first PR to prove that they work, and assess their usage. I'll do the MaxViTTransformerEncoder and we can open the components PR.

It'd be a good idea to see whether we can generalize the existing transformer encoder so it can be shared between ViTs and MaxViTs, since they're not too different (and allow the type of multi-head attention to be changed). The main counterargument is that it already has quite a few arguments, so a general encoder with even more might not be very user-friendly.

Thoughts?

bhack commented 1 year ago

> The main counterargument is that it already has quite a few arguments so having a general encoder with many might not be very user-friendly.

Generally, this could be an indirect signal that it needs a base class.
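A base class could keep the shared pre-norm residual structure in one place and leave only the attention choice to subclasses, instead of adding more constructor arguments. The following is a hypothetical sketch of that design (`BaseTransformerEncoder`, `make_attention`, `attend` and `MHAEncoder` are illustrative names, not KerasCV API), using core Keras `MultiHeadAttention` for the ViT-style subclass.

```python
import tensorflow as tf


class BaseTransformerEncoder(tf.keras.layers.Layer):
    """Shared pre-norm residual structure; subclasses choose the attention."""

    def __init__(self, project_dim, num_heads, mlp_dim, layer_norm_epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        # Assigned as an attribute, so Keras tracks its weights.
        self.attention = self.make_attention()
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_dim, activation="gelu"),
            tf.keras.layers.Dense(project_dim),
        ])

    def make_attention(self):
        """Return the attention layer; must be overridden."""
        raise NotImplementedError

    def attend(self, h):
        """Self-attention call; override if the layer's signature differs."""
        return self.attention(h, h)

    def call(self, x):
        x = x + self.attend(self.norm1(x))
        return x + self.mlp(self.norm2(x))


class MHAEncoder(BaseTransformerEncoder):
    """Plain multi-head self-attention variant (ViT-style)."""

    def make_attention(self):
        return tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.project_dim // self.num_heads
        )
```

A relative-attention variant would then override only `make_attention` (and `attend`, if its call signature differs), leaving the constructor arguments that users see unchanged.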

DavidLandup0 commented 1 year ago

For reference, this is the constructor:

def __init__(
    self,
    project_dim,
    num_heads,
    mlp_dim,
    mlp_dropout=0.1,
    attention_dropout=0.1,
    activation=tf.keras.activations.gelu,
    layer_norm_epsilon=1e-06,
    attention_type='mha',
    **kwargs,
):

Though, because of the defaults, usage can be as simple as:

keras_cv.layers.TransformerEncoder(
    project_dim=project_dim,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
)(encoded_patches)

Now - I remember KerasNLP having this same issue. We might not be able to have a fully general TransformerEncoder for all cases, so it might be better to do them separately?

In the case of MaxViT, it's one extra arg, that simply defines:

if attention_type == 'mha':
    attention_layer = layers.MultiHeadAttention
elif attention_type == 'relmha':
    attention_layer = layers.RelativeMultiHeadAttention
else:
    raise ValueError(f"Unknown attention_type: {attention_type}")

So it's a small change. The question is mainly for work down the line when we might need to support more options.

ayulockin commented 1 year ago

I am in favour of a separate TransformerEncoder. It allows for speedy implementation since vision transformers rapidly evolve.

The counterargument is that we implement a handful of vision transformers and then try to build a unified transformer encoder by introducing a base class.

tanzhenyu commented 1 year ago

Great progress! The breakdown of those components sounds good to me. @vztu @Yinxiaoli can you comment here?

Re David's question: I think it'd be nice to have a transformer encoder that accepts different attention mechanisms, though we don't have plans to move relative attention to core Keras yet (maybe later, given there are so many different attention variants out there). If MaxViT can reuse the encoder, that'd be great; the core value of KCV is always to provide generic components.