NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

mamba_vision_T is not as fast as expected #13

Closed · emailweixu closed this issue 3 months ago

emailweixu commented 3 months ago

Thanks for making MambaVision; I was excited to use it for faster training.

However, on a machine with a 3090 GPU, I got similar throughput for mamba_vision_T and convnext_tiny: mamba_vision_T: 2916 images/s, convnext_tiny: 2680 images/s.

In Table 1 of the paper, though, the throughput of mamba_vision_T is almost twice that of convnext_tiny. Is this expected?

The following is the code I used to get the throughput

import torch
import timm
import time

# model = timm.create_model('mamba_vision_T', pretrained=True)
model = timm.create_model('convnext_tiny', pretrained=True)
model = model.cuda()  # place the model on the GPU

x = torch.randn(128, 3, 224, 224, device='cuda')

start = time.time()
with torch.cuda.amp.autocast(True):
    with torch.no_grad():
        for i in range(100):
            y2 = model(x)

t = time.time() - start
print("throughput:", 100 * x.shape[0] / t)
ahatamiz commented 3 months ago

Thanks for the note. We performed the measurements on an A100 GPU, and there are A100-specific optimizations that may explain the discrepancy.

Additionally, for measuring throughput it is recommended to use 100 warm-up iterations and to increase the number of timed iterations from 100 to 500 in the code shown above.

Lastly, I recommend the channels-last memory layout. All in all, it should look like this:

import time

import numpy as np
import timm
import torch

bs = 128          # batch size (128 matches the numbers reported above)
resolution = 224  # input resolution

model = timm.create_model('mamba_vision_T', pretrained=True)
input_data = torch.randn((bs, 3, resolution, resolution), device='cuda')
input_data = input_data.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
model.cuda()
model.eval()

# warm up
with torch.cuda.amp.autocast():
    for ii in range(100):
        with torch.no_grad():
            output = model(input_data)

# speed
timer = []
runs = 500
with torch.cuda.amp.autocast(True):
    for ii in range(runs):
        start_time_loc = time.time()
        with torch.no_grad():
            output = model(input_data)
        timer.append(time.time() - start_time_loc)
    torch.cuda.synchronize()

print(f"Throughput: {int(bs / np.median(timer))}")
emailweixu commented 3 months ago

Thanks for the prompt reply. Here is the updated throughput:

model           channels_last   throughput (images/s)
convnext_tiny   False           2533
convnext_tiny   True            2532
mamba_vision_T  False           2922
mamba_vision_T  True            3355

channels_last does help for mamba_vision_T, but it is still much slower than the resnet18 I am using right now. Do you think it is possible to make a MambaVision model that is faster than resnet18 with comparable accuracy? In my use case, speed is very important.

ahatamiz commented 3 months ago

Thanks @emailweixu. I would like to reiterate that the specific hardware we targeted, the A100, is important for fully realizing the potential of MambaVision and achieving the intended accuracy vs. throughput trade-off. For example, the A100 has 32% more tensor cores than the 3090 used in your setup.

Regarding the ResNet-18 model, its Top-1 accuracy on torchhub is reported to be 69.76%. It is certainly possible to match this level of accuracy while being comparable to or faster than that model. However, we have not targeted it yet, mostly because the very low accuracy makes it hard to argue for a use case.

emailweixu commented 3 months ago

Thanks @ahatamiz. I agree that both FlashAttention and the Mamba SSM kernels are perhaps better optimized for the A100. What I actually want is a model as fast as resnet18 but with better accuracy.

saarthakk-insitro commented 5 days ago

(Quoting @ahatamiz's earlier reply and benchmark snippet above.)

Hi Authors, thank you for making the code public and sharing thoughts on throughput calculations.

I was going through some blog posts (e.g. https://towardsdatascience.com/the-correct-way-to-measure-inference-time-of-deep-neural-networks-304a54e5187f) and the throughput calculation in the VMamba repo (https://github.com/MzeroMiko/VMamba/blob/main/analyze/utils.py#L481), and I believe that reporting the suggested print(f"Throughput {int(bs * 1.0 / ((np.median(timer))))}") can give inflated throughput, since synchronize isn't called before each time is appended to the timer list.
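
To illustrate, a minimal correction of that loop (same model, input_data, bs, and runs as above) synchronizes before each timer read:

import time
import numpy as np
import torch

timer = []
with torch.cuda.amp.autocast(True):
    for ii in range(runs):
        torch.cuda.synchronize()           # drain any previously queued GPU work
        start_time_loc = time.time()
        with torch.no_grad():
            output = model(input_data)
        torch.cuda.synchronize()           # wait for this forward pass to finish
        timer.append(time.time() - start_time_loc)

print(f"Throughput: {int(bs / np.median(timer))}")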

I tried benchmarking on an A100 with batch size 128 and found the throughput of ViM to be much lower than reported in your work (2357 for ViM-T and 1219 for ViM-S, both with autocast turned on, versus roughly 4000 for ViM-T and 2000 for ViM-S as reported).

A more thorough benchmark (ViM throughput, images/s):

model        autocast=True   autocast=False
224_tiny     2357            1854
224_small    1219            782
224_base     600             287
224_large    211             91

The code I used for throughput analysis:

import numpy as np
import timm
import torch

bs = 128
runs = 500
warmup_runs = 100
use_autocast = True  # False
resolution = 224

# network under test, e.g. via timm (substitute the ViM model being benchmarked)
model = timm.create_model('mamba_vision_T', pretrained=True)

model = model.to(memory_format=torch.channels_last)
model.cuda()
model.eval()

# warm up
with torch.cuda.amp.autocast(enabled=use_autocast):
    for ii in range(warmup_runs):
        input_data = torch.randn((bs, 3, resolution, resolution), device='cuda')
        input_data = input_data.to(memory_format=torch.channels_last)
        with torch.no_grad():
            output = model(input_data)
        del input_data

# speed: time each forward pass on the GPU timeline with CUDA events
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(runs)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(runs)]

with torch.cuda.amp.autocast(enabled=use_autocast):
    for ii in range(runs):
        input_data = torch.randn((bs, 3, resolution, resolution), device='cuda')
        input_data = input_data.to(memory_format=torch.channels_last)

        with torch.no_grad():
            start_events[ii].record()
            output = model(input_data)
            end_events[ii].record()

        del input_data
        del output

torch.cuda.synchronize()
times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]  # milliseconds
print(int(bs * 1.0 / (np.median(times) / 1000)))
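
For reference, torch.cuda.Event.elapsed_time returns milliseconds on the GPU timeline, hence the division by 1000, and a single torch.cuda.synchronize() after the loop suffices because all recorded events have completed by that point.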

Could you please let me know if I am doing anything wrong? I am aiming for a submission in November, so I need to verify the throughput setup quickly.