rstrudel / segmenter

[ICCV2021] Official PyTorch implementation of Segmenter: Transformer for Semantic Segmentation
MIT License

Code to compute images/sec #44

Closed: prabhuteja12 closed this issue 2 years ago

prabhuteja12 commented 2 years ago

Hi,

Thank you for the cool work!

I see that you report images/sec, and mention the following in the paper:

To compute the images per second, we use a V100 GPU, fix the image resolution to 512 and for each model we maximize the batch size allowed by memory for a fair comparison.
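The quoted step "maximize the batch size allowed by memory" can be sketched as a doubling-then-bisection search. This is a minimal sketch, not the authors' code: it assumes a hypothetical `fits(batch_size)` predicate that returns True when a forward pass at that batch size succeeds (with PyTorch one might run a forward pass under `torch.no_grad()` and catch the CUDA out-of-memory error), and it assumes fitting is monotone (if a batch fits, every smaller batch fits). The function name and the default `limit` are illustrative.

```python
from typing import Callable


def max_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 4096) -> int:
    """Largest b in [start, limit] with fits(b) True; 0 if even `start` fails.

    Assumes `fits` is monotone: if a batch size fits, every smaller one does too.
    """
    if start < 1:
        raise ValueError("start must be >= 1")
    if not fits(start):
        return 0
    good = start  # largest size known to fit
    bad = None    # smallest size known not to fit
    # Exponential phase: keep doubling while batches fit.
    while good < limit:
        candidate = min(good * 2, limit)
        if fits(candidate):
            good = candidate
        else:
            bad = candidate
            break
    if bad is None:
        return good  # even `limit` fits
    # Bisection phase, invariant: fits(good) is True and fits(bad) is False.
    while bad - good > 1:
        mid = (good + bad) // 2
        if fits(mid):
            good = mid
        else:
            bad = mid
    return good


# Toy predicate standing in for "a forward pass at this batch size does not run out of memory"
print(max_batch_size(lambda b: b <= 140))  # 140
```

Doubling first keeps the number of expensive trial forward passes logarithmic in the final batch size.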

I'm trying to do the same, but I'm unable to reproduce the images/sec numbers reported in the paper.

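I'm using the following snippet, based on PyTorch's benchmarking utilities:

    import torch
    from torch.utils.benchmark import Timer

    # args.batch_size and input_shape (e.g. (3, 512, 512)) are defined elsewhere in the script
    batch = torch.rand(args.batch_size, *input_shape).cuda()
    model(batch)  # one untimed forward pass before timing
    n_runs = 10

    t = Timer(stmt="model.forward(batch)", globals={"model": model, "batch": batch})
    m = t.timeit(n_runs)  # Measurement; m.mean is the mean time per call, in seconds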

The largest batch size that fits on a V100 with the ViT-T backbone is about 140, and the code above reports 0.62 seconds per forward pass, so I compute the throughput as 140 / 0.62 ≈ 225.8 images/sec. This is almost half the number reported in Table 3. Could you help me figure out what I need to do to get the reported result?
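The throughput arithmetic above, written as a small helper: images/sec is the batch size divided by the mean wall-clock time of one forward pass. `images_per_second` is a hypothetical name used only for illustration.

```python
from statistics import mean
from typing import Sequence


def images_per_second(batch_size: int, seconds_per_batch: Sequence[float]) -> float:
    """Throughput = images in one batch / mean time of one forward pass (seconds)."""
    return batch_size / mean(seconds_per_batch)


# One timed batch of 140 images taking 0.62 s, as in the numbers above
print(round(images_per_second(140, [0.62]), 1))  # 225.8
```

Averaging over several timed iterations, rather than a single one, makes the estimate less sensitive to one slow or fast run.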

Thank you!

rstrudel commented 2 years ago

Hi @prabhuteja12 ,

To perform robust benchmarking, you can check the timm repository, in particular the following file:

https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py#L244

Two things that timm does and that I do not see in your snippet are:

  1. warmup steps
  2. torch.no_grad

I hope this helps!

prabhuteja12 commented 2 years ago

Hi @rstrudel ,

Thank you for the quick reply, and the pointers to timm.

Additionally, timm uses perf_counter, and so does PyTorch's Timer.

So I'm unable to figure out what is causing this large drop in speed. If you still have your benchmarking script, could you share it?

rstrudel commented 2 years ago

The snippet of code is based on timm benchmarking and we do not plan to release a clean version of it as of now. Table 3 refers to Segmenter models with a linear decoder, did you make sure you measured timings with this model and not the mask decoder (which is more expensive)? I will try to check this in the following weeks.

prabhuteja12 commented 2 years ago

@rstrudel I think I was able to get numbers in that range. The data and the model need to be in float16. Could you confirm that this is what you did as well?

Thanks!

rstrudel commented 2 years ago

Here is the code snippet I used to measure throughput. Good call: I was indeed using mixed precision, well spotted!

Your use of Timer is probably a cleaner way of doing it; I am not sure it was available (or at least I did not know about it) at the time of the experiments.

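    import time

    import torch


    @torch.no_grad()  # inference only: no autograd graph is recorded
    def compute_throughput(model, batch_size, resolution, device="cuda"):
        torch.cuda.empty_cache()
        warmup_iters = 3
        num_iters = 30

        timing = []

        # Random input at the benchmark resolution
        x = torch.randn(batch_size, 3, resolution, resolution, device=device)

        # Warmup: untimed forward passes
        for _ in range(warmup_iters):
            with torch.cuda.amp.autocast():
                y = model(x)

        MB = 1024.0 * 1024.0
        torch.cuda.synchronize()
        for i in range(num_iters):
            memory = int(torch.cuda.max_memory_allocated() / MB)
            if i == 0:
                print(f"memory: {memory} MB")
            start = time.perf_counter()
            # Mixed precision: autocast runs eligible ops in float16
            with torch.cuda.amp.autocast():
                y = model(x)
            # CUDA kernels run asynchronously; wait for them before reading the clock
            torch.cuda.synchronize()
            timing.append(time.perf_counter() - start)

        timing = torch.as_tensor(timing, dtype=torch.float32)
        # Images per second = batch size / mean time per forward pass
        return (batch_size / timing.mean()).item()
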
prabhuteja12 commented 2 years ago

Great! Thank you!

KyloRen1 commented 1 year ago

Hello,

I'm trying to benchmark the speed of the ViT-tiny backbone with the mask decoder. However, I can only reach a batch size of 54, even though there seems to be memory available for more. While benchmarking on a Tesla V100 32 GB with torch 2.0.1 and 1.13.1, I encountered the following error:

RuntimeError: upsample_bilinear2d_nhwc only supports output tensors with less than INT_MAX elements

Did you encounter this error, and if so, were you able to work around it? Thank you
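The error message points to a 32-bit element-count limit on the upsampled output. A back-of-the-envelope check, under two assumptions not stated in the thread: the final bilinear upsampling produces a tensor of shape (batch, num_classes, 512, 512), and the dataset is ADE20K with 150 classes. Under these assumptions, 54 is exactly the largest batch size that stays under the limit, which is consistent with (but does not prove) this explanation. The helper names are hypothetical.

```python
INT_MAX = 2**31 - 1  # largest signed 32-bit integer: 2,147,483,647


def upsample_output_elements(batch_size: int, num_classes: int, height: int, width: int) -> int:
    """Number of elements in a (batch, classes, height, width) output tensor."""
    return batch_size * num_classes * height * width


def max_batch_under_int_max(num_classes: int, height: int, width: int) -> int:
    """Largest batch size whose output has strictly fewer than INT_MAX elements."""
    per_image = num_classes * height * width
    return (INT_MAX - 1) // per_image


# Assumptions: 150 classes (ADE20K) and a 512x512 output
print(max_batch_under_int_max(150, 512, 512))  # 54
print(upsample_output_elements(55, 150, 512, 512) > INT_MAX)  # True
```

If this is the cause, the limit depends on the output size and number of classes rather than on the backbone or available GPU memory.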