Peterande / SAST

CVPR 2024: SAST: Scene Adaptive Sparse Transformer for Event-based Object Detection
MIT License

question #1

Closed: fengwei0907 closed this issue 4 months ago

fengwei0907 commented 4 months ago

How is the FLOPs (G) metric in the paper calculated? For example, Table 1 reports 3.5 GFLOPs for RVT on the Gen1 dataset. In other words, where is the code that calculates this metric? After reproducing RVT, I only obtained timings in seconds.

Peterande commented 4 months ago

You can follow "https://github.com/Peterande/SAST/blob/master/benchmark.py" to calculate A-FLOPs by adding the following lines to train.py or validation.py:

    # benchmark.py lives at the repository root, so import it absolutely;
    # "from .benchmark import ..." fails when train.py is run as a script.
    from benchmark import compute_fps, compute_gflops
    from tabulate import tabulate

    # config, utils, fetch_model_module and fetch_data_module come from the
    # surrounding script; `sparsity` must also be defined before this point.
    model = fetch_model_module(config=config).cuda()
    dataset = fetch_data_module(config=config)
    dataset.setup(stage='test')
    dataset = dataset.test_dataset
    # model = torch.compile(model)
    if utils.is_main_process() and config.benchmark:
        # Count trainable parameters only.
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        fps = compute_fps(model, dataset, num_iters=100, batch_size=1, sparsity=sparsity)
        # "B4FPS" column: throughput at batch size 4 (assumed from the column name).
        bfps = compute_fps(model, dataset, num_iters=100, batch_size=4, sparsity=sparsity)
        # GFLOPs estimate; see compute_gflops in benchmark.py for what approximated=True does.
        gflops = compute_gflops(model, dataset, approximated=True, sparsity=sparsity)

        # model = torch.jit.script(model)
        # model = torch.jit.trace(model, images, strict=False)
        tab_keys = ["#Params(M)", "GFLOPs", "FPS", "B4FPS"]
        tab_vals = [n_params / 10 ** 6, gflops, fps, bfps]
        table = tabulate([tab_vals], headers=tab_keys, tablefmt="pipe",
                         floatfmt=".3f", stralign="center", numalign="center")
        print("===== Benchmark (Crude Approx.) =====\n" + table)
    model = None  # drop the reference to the benchmark model

I removed that part from the released code because it was very time-consuming.

fengwei0907 commented 4 months ago

Got it! Thank you!

fengwei0907 commented 4 months ago

Hello! I have another question. After reproducing RVT, I get the following timing output:

    Loading and preparing results DONE (t=0.41s)
    creating index
    Index created!
    Running per image evaluation
    Evaluate annotation type bbox DONE (t=7.23s)
    Accumulating evaluation results DONE (t=2.31s)

How can I get the inference speed in ms, like the 16 ms reported in your paper for the 1 Mpx dataset? I am using an RTX 3090 Ti GPU with a batch size of 1; what speed in ms should I expect?

Peterande commented 4 months ago

Speed is measured with compute_fps in the benchmark code, not during the evaluation process. During evaluation, most of the time is spent computing the metrics.

Peterande commented 4 months ago

Try compute_fps.
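At batch size 1, a throughput of FPS frames per second corresponds to a per-frame latency of 1000 / FPS milliseconds, so the FPS printed by compute_fps can be converted directly (for example, 62.5 FPS is 16 ms). If you want to time a forward pass yourself, here is a minimal sketch under stated assumptions: `step` is a hypothetical zero-argument callable that runs one forward pass, and nothing here reproduces what compute_fps does internally.

```python
import time
from typing import Callable


def measure_latency_ms(step: Callable[[], object], num_iters: int = 100, warmup: int = 10) -> float:
    """Return the mean wall-clock time of one call to `step`, in milliseconds."""
    for _ in range(warmup):  # warm-up calls are not timed
        step()
    start = time.perf_counter()
    for _ in range(num_iters):
        step()
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / num_iters


def fps_to_ms(fps: float) -> float:
    """Per-frame latency in ms for a given throughput, assuming batch size 1."""
    return 1000.0 / fps


if __name__ == "__main__":
    # Dummy workload standing in for one forward pass of the model.
    ms = measure_latency_ms(lambda: sum(range(10_000)), num_iters=50)
    print(f"{ms:.3f} ms per call")
    print(fps_to_ms(62.5))  # 16.0
```

On a GPU, the callable should call torch.cuda.synchronize() before returning: CUDA kernels run asynchronously, so without it the timer can stop before the work has finished. Metric computation (the COCO evaluation output above) is deliberately outside the timed loop.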

fengwei0907 commented 4 months ago

Okay. My experiment is still running, so I haven't been able to check whether this code runs yet. I'll get back to you after I test it. Thank you for your reply.

fengwei0907 commented 4 months ago

Hello! I tried your code, but only computed the GFLOPs metric. Your paper says the FLOPs were computed on the first 1000 samples of the Gen1 test set; how exactly is that implemented? The attachment shows a screenshot of my run. Is it enough to let it run to completion? (Attachment: 实验问题.docx, "experiment questions")

Peterande commented 4 months ago

109 is the index of the sample within the sequence. You do not need to calculate FLOPs for all 470 sequences; you can break out of the loop after 1000 samples. By the way, the average FLOPs obtained from each training run may differ from those in the paper, because RVT's training strategy cannot fix the random seed, so the network learns different sparsity levels each time.
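Breaking out after a fixed number of samples can be sketched as follows. This is a hedged illustration, not the code in benchmark.py: `gflops_of` is a hypothetical per-sample FLOP counter, and `samples` is any iterable that walks through the test sequences in order.

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def average_gflops(samples: Iterable[T], gflops_of: Callable[[T], float],
                   max_samples: int = 1000) -> float:
    """Mean GFLOPs over the first `max_samples` samples of `samples`.

    The loop breaks as soon as `max_samples` samples have been measured, so
    later samples are never requested from the iterable.
    """
    if max_samples <= 0:
        raise ValueError("max_samples must be positive")
    total = 0.0
    count = 0
    for sample in samples:
        total += gflops_of(sample)
        count += 1
        if count >= max_samples:
            break  # stop after the first max_samples samples
    if count == 0:
        raise ValueError("no samples to measure")
    return total / count


if __name__ == "__main__":
    # Toy stream of 5000 samples with a constant hypothetical cost of 2.0 GFLOPs.
    print(average_gflops(range(5000), lambda s: 2.0))  # 2.0
```

With a lazy iterable (for example, a generator over the test sequences), the sequences after the 1000th sample are never loaded.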

fengwei0907 commented 4 months ago

Does each of the 470 sequences contain 109 samples? So should we stop the program at around sequence 9/470?

fengwei0907 commented 4 months ago

I'm not sure whether it's due to my parameter settings, but why can't I get 3.5 GFLOPs for RVT? Did you retrain RVT before measuring GFLOPs? I get 5.129 GFLOPs for RVT. I did not retrain RVT; I just used its checkpoint.

Peterande commented 4 months ago

3.5 GFLOPs. Which dataset? RVT or SAST?

fengwei0907 commented 4 months ago

Gen1, RVT.

Peterande commented 4 months ago

Try measuring the backbone only.

Peterande commented 4 months ago

Read the caption of Table 1.

fengwei0907 commented 4 months ago

OK, OK. How do I set it up to measure only the backbone, bro?

Peterande commented 4 months ago

See 'forward' in https://github.com/Peterande/SAST/blob/master/modules/detection.py.

fengwei0907 commented 4 months ago

I also found that estimating GFLOPs from a single input gives the same result as measuring over 1000 samples. Why is that? Haha.

Peterande commented 4 months ago

I think you are probably using the original code of RVT.

Peterande commented 4 months ago

The results differ only for sparse networks.
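Why one sample and 1000 samples give the same number for a dense model can be illustrated with a toy cost model. The functions below are hypothetical and are not how RVT or SAST count FLOPs: the dense cost depends only on the (fixed) input size, while the sparse cost shrinks with the fraction of tokens a sample skips, so only the sparse model's average depends on which samples are measured.

```python
from statistics import mean


def dense_gflops(sample_sparsity: float, full_cost: float = 5.0) -> float:
    """Toy dense network: cost is fixed by the input size; sparsity is ignored."""
    return full_cost


def sparse_gflops(sample_sparsity: float, fixed_cost: float = 1.0, full_cost: float = 5.0) -> float:
    """Toy sparse network: a fixed part plus a part that shrinks with the
    fraction of tokens skipped for this sample (`sample_sparsity` in [0, 1])."""
    return fixed_cost + (full_cost - fixed_cost) * (1.0 - sample_sparsity)


# Hypothetical per-sample sparsity levels (fraction of tokens skipped).
sparsities = [0.2, 0.5, 0.8, 0.9]

if __name__ == "__main__":
    # Dense: one sample and the average over all samples agree.
    print(dense_gflops(sparsities[0]), mean(dense_gflops(s) for s in sparsities))  # 5.0 5.0
    # Sparse: the single-sample value and the average differ.
    print(sparse_gflops(sparsities[0]), mean(sparse_gflops(s) for s in sparsities))
```

This is also why the averaged FLOPs of a sparse model can shift between training runs: a different learned sparsity level changes every per-sample cost.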

Peterande commented 4 months ago

If you are confused about issues specific to RVT, please ask the authors of that paper.

fengwei0907 commented 4 months ago

My work modifies RVT, but to try your method of measuring GFLOPs I used the original RVT code. Thank you, I'll take another look.

fengwei0907 commented 4 months ago

Thank you, I now get GFLOPs for RVT similar to the value in your paper.

Peterande commented 4 months ago

You are welcome

Zizzzzzzz commented 2 months ago

Hello, how should the sparsity argument be obtained when calculating FPS and FLOPs for SAST?