NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

About your architecture.. #22

Closed LiamSumin closed 1 month ago

LiamSumin commented 1 month ago

Hello, Ali

Thank you for your excellent architectural design; I would like to use it in my research.

In that regard, I have a question about your network.

In Stage 3 and Stage 4, your network performs state space modeling through the MambaVision mixer and then applies self-attention to extract features.

Have you ever experimented with applying the Stage 3 and 4 architecture to all stages?

In addition, your Stages 1 and 2 use convolution blocks to extract local features. With a similar architecture, Deformable ViT uses the shifted window attention proposed in the Swin Transformer for Stages 1 and 2 to keep the computational complexity manageable.

So, have you ever experimented with using a Transformer block instead of a CNN block in Stages 1 and 2?

If it is difficult to reply here, please reply to tnals4026@chungbuk.ac.kr !

Thank you.

ahatamiz commented 1 month ago

Hi @LiamSumin

Thank you for the question. We did not use the stage 3/4 design for stages 1/2 because of our design principle of achieving the best accuracy-vs-throughput trade-off on a GPU. In this case, we identify the nature of each stage, in terms of being memory bound or compute bound, and tailor the operations used in it to maximize that trade-off (i.e., compute-intensive operations for the memory-bound stages and expressive operations for the compute-bound stages).

Specifically, stages 1 and 2 are memory bound because they have large spatial dimensions with small channel dimensions. As a result, compute-intensive operations such as dense convolution are better suited for these layers. Operations that do not take advantage of matrix multiplications (tensor rolling, windowing/de-windowing, etc.) should be minimized in these stages because of their extra data-transfer cost. Expressive operations such as SSMs and self-attention, in turn, are suited for stages 3 and 4, which are compute bound.
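For a rough sense of what "large spatial dimensions with small channel dimensions" means in practice, here is a small back-of-the-envelope sketch. The base width C and the stride schedule below are typical hierarchical-backbone values used purely as an example, not necessarily MambaVision's exact configuration: the early stages have far more tokens, each carrying far fewer channels, so their cost is dominated by moving activations rather than by matrix math.

# Illustrative only: tokens vs. channels per stage for a 224x224 input in a
# typical 4-stage hierarchical backbone (strides 4, 8, 16, 32).
C = 80  # example base channel width (assumption, not necessarily MambaVision's)

for stage, (stride, mult) in enumerate([(4, 1), (8, 2), (16, 4), (32, 8)], start=1):
    h = w = 224 // stride
    tokens, channels = h * w, C * mult
    act_mb = tokens * channels * 2 / 2**20          # fp16 activations per image, MB
    print(f"stage {stage}: {h:2d}x{w:2d} = {tokens:4d} tokens, "
          f"{channels:3d} channels, ~{act_mb:.2f} MB activations")

With these example numbers, stage 1 holds roughly 8x the activation data of stage 4 while offering only 1/8 of the channel width per token, which is why plain dense convolutions suit it better than reshaping-heavy attention variants.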

The above design principle allows for maximizing throughput and utilization by considering how accelerated computing works on GPUs.
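To make the resulting layout concrete, here is a minimal, self-contained sketch of the hybrid design described above. It is not the actual MambaVision implementation: the widths and depths are example values, and the SSM-based MambaVision mixer of stages 3/4 is replaced by plain self-attention so the snippet runs with stock PyTorch.

# Sketch only -- NOT the MambaVision implementation. It illustrates the design
# principle: dense conv blocks in the memory-bound stages 1-2, token mixing
# (SSM mixer + self-attention in the real model) in the compute-bound stages 3-4.
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Dense 3x3 residual conv block for the early, memory-bound stages."""
    def __init__(self, dim):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1), nn.BatchNorm2d(dim), nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1), nn.BatchNorm2d(dim),
        )

    def forward(self, x):                        # x: (B, C, H, W)
        return x + self.f(x)


class MixerAttnBlock(nn.Module):
    """Stand-in token mixer for the later stages; the real model uses an SSM
    mixer followed by self-attention, here plain self-attention is used."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim),
                                 nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):                        # x: (B, C, H, W)
        b, c, h, w = x.shape
        t = x.flatten(2).transpose(1, 2)         # (B, H*W, C) tokens
        n = self.norm(t)
        t = t + self.attn(n, n, n, need_weights=False)[0]
        t = t + self.mlp(t)
        return t.transpose(1, 2).reshape(b, c, h, w)


class HybridBackboneSketch(nn.Module):
    """Four-stage hierarchy: conv blocks where the stage is memory bound,
    token-mixing blocks where it is compute bound."""
    def __init__(self, dim=80, depths=(1, 1, 2, 2)):   # example widths/depths
        super().__init__()
        dims = [dim, dim * 2, dim * 4, dim * 8]
        self.stem = nn.Conv2d(3, dims[0], kernel_size=4, stride=4)   # 224 -> 56
        stages = []
        for i, d in enumerate(depths):
            block = ConvBlock if i < 2 else MixerAttnBlock           # the key choice
            stages.append(nn.Sequential(*[block(dims[i]) for _ in range(d)]))
            if i < 3:                                                # downsample between stages
                stages.append(nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        return self.body(self.stem(x))


y = HybridBackboneSketch()(torch.randn(1, 3, 224, 224))
print(y.shape)                                   # torch.Size([1, 640, 7, 7])

The only point the sketch makes is the per-stage choice of block type: convolutional residual blocks while the feature map is large and thin, and token-mixing blocks once it is small and wide.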

I hope this helps.

LiamSumin commented 1 month ago

Thank you for your answer!

I have one more question about inference speed.

When I measured the speed of your model with the code below, I got the following result.

device = torch.device("cuda")
model = mamba_vision_B()
model.to(device)
x = torch.rand(32, 3, 224, 224).to(device)

import time
start_t = time.time()
y = model(x)
print(f"FPS: {1/(time.time() - start_t)}")

FPS: 5.955263509119668

Since the model uses Mamba, I expected inference to be fast, but it wasn't. I wonder whether I measured it correctly.

Have you ever measured the inference speed on downstream tasks using MambaVision as a backbone?

And if this really is the model's inference speed, how can I achieve faster inference while maintaining accuracy?

Thank you.

ahatamiz commented 1 month ago

Hi @LiamSumin

It is not clear what FPS denotes in your snippet above. Traditionally, FPS stands for frames per second, which is not applicable to this use case.

I have a snippet here that lets you measure throughput in a systematic way (with proper warm-up, etc.). In addition, we recommend the A100 as our optimized hardware of choice to fully realize the benefits of our architecture; for example, the Flash Attention and SSM implementations benefit greatly from that GPU architecture:

import time

import numpy as np
import timm
import torch

# Note: 'mamba_vision_T' must be registered with timm, e.g. by having the
# MambaVision package installed and imported.
bs = 128          # batch size for the measurement (example value)
resolution = 224  # input resolution

model = timm.create_model('mamba_vision_T', pretrained=True)
input_data = torch.randn((bs, 3, resolution, resolution), device='cuda')
input_data = input_data.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
model.cuda()
model.eval()

# warm up
with torch.cuda.amp.autocast():
    for ii in range(100):
        with torch.no_grad():
            output = model(input_data)

# speed
timer = []
runs = 500
with torch.cuda.amp.autocast(True):
    for ii in range(runs):
        start_time_loc = time.time()
        with torch.no_grad():
            output = model(input_data)
        timer.append(time.time() - start_time_loc)
    torch.cuda.synchronize()

# throughput in images per second, based on the median per-batch time
print(f"Throughput {int(bs * 1.0 / (np.median(timer)))}")

I hope this helps.