sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

Is there some bug in the 'input_constructor' function? #117

Closed: CanChengZheng closed this issue 1 year ago

CanChengZheng commented 1 year ago

I encountered the following exception when using the input_constructor function to build the input:

Flops estimation was not finished successfully because ofthe following exception:
<class 'UnboundLocalError'> : local variable 'batch' referenced before assignment
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/Fas_PatchNet_Multiplier/lib/python3.10/site-packages/ptflops/pytorch_engine.py", line 60, in get_flops_pytorch
    _ = flops_model(batch)
UnboundLocalError: local variable 'batch' referenced before assignment

Here is my analysis; I hope it helps!

  1. In pytorch_engine.py, when input_constructor is provided, the first branch builds the input and runs the model on it directly, but never assigns batch:

    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        try:
            batch = torch.ones(()).new_empty((1, *input_res),
                                             dtype=next(flops_model.parameters()).dtype,
                                             device=next(flops_model.parameters()).device)
        except StopIteration:
            batch = torch.ones(()).new_empty((1, *input_res))
  2. The line _ = flops_model(batch) is executed next (as shown below). Because batch was never assigned in the input_constructor branch, this line raises UnboundLocalError. The exception interrupts the FLOPs computation, and the function returns None, None.

    try:
        _ = flops_model(batch)
        flops_count, params_count = flops_model.compute_average_flops_cost()
        flops_count += sum(torch_functional_flops)
        flops_count += sum(torch_tensor_ops_flops)
    except Exception as e:
        print("Flops estimation was not finished successfully because of"
              f"the following exception:\n{type(e)} : {e}")
        traceback.print_exc()
        reset_environment()
        return None, None
  3. I checked the commit history, and the regression appears to have been introduced in commit 2cb2de11.
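The failure described in items 1 and 2 can be reproduced without PyTorch. The sketch below mirrors only the control flow of the quoted excerpt; simulate_flops, toy_model and make_inputs are hypothetical names, a plain list stands in for the dummy tensor, and the exception name is returned instead of printing a traceback.

```python
def toy_model(x):
    # Hypothetical stand-in for flops_model: accepts one input, positionally or as x=...
    return sum(x)


def make_inputs(input_res):
    # Hypothetical input_constructor: returns a dict of keyword arguments.
    return {"x": [1.0] * input_res[0]}


def simulate_flops(model, input_res, input_constructor=None):
    # Same branching as the quoted excerpt (simplified, no torch).
    if input_constructor:
        inputs = input_constructor(input_res)
        _ = model(**inputs)  # the model runs here, but `batch` is never bound
    else:
        batch = [0.0] * input_res[0]  # stand-in for the dummy input tensor

    try:
        _ = model(batch)  # raises UnboundLocalError on the input_constructor path
        return "finished"
    except Exception as e:
        return type(e).__name__


print(simulate_flops(toy_model, (3,)))               # finished
print(simulate_flops(toy_model, (3,), make_inputs))  # UnboundLocalError
```

Because batch is assigned somewhere in the function, Python treats it as a local variable everywhere in that function, so reading it on the path that skipped the assignment raises UnboundLocalError rather than NameError.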

Could you please confirm whether this is a bug? Thank you!

CanChengZheng commented 1 year ago

git sha: 2cb2de11

sovrasov commented 1 year ago

@CanChengZheng thanks for reporting; the corresponding fix will be merged soon.
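One possible shape for a fix, sketched in plain Python with hypothetical names (simulate_flops_fixed, toy_model, make_inputs) and a list standing in for the dummy tensor: bind batch in both branches, then decide at call time whether to unpack it as keyword arguments. This is an illustration under those assumptions, not necessarily the patch that was merged.

```python
def toy_model(x):
    # Hypothetical stand-in for flops_model.
    return sum(x)


def make_inputs(input_res):
    # Hypothetical input_constructor returning keyword arguments.
    return {"x": [1.0] * input_res[0]}


def simulate_flops_fixed(model, input_res, input_constructor=None):
    # Build the input in both branches so `batch` is always bound.
    if input_constructor:
        batch = input_constructor(input_res)
    else:
        batch = [0.0] * input_res[0]  # stand-in for the dummy input tensor

    try:
        # Dict inputs are unpacked as keyword arguments; anything else is passed positionally.
        if isinstance(batch, dict):
            _ = model(**batch)
        else:
            _ = model(batch)
        return "finished"
    except Exception as e:
        return type(e).__name__


print(simulate_flops_fixed(toy_model, (3,)))               # finished
print(simulate_flops_fixed(toy_model, (3,), make_inputs))  # finished
```

With this structure the model runs exactly once per call, and because that call sits inside the try block, an exception raised while running the model on constructor-built inputs is reported through the same error path as any other failure.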