zugexiaodui / torch_flops

A library for calculating the FLOPs in the forward() process based on torch.fx
MIT License

Flops For GFP-GAN(StyleGan) #7

Closed Liar-zzy closed 10 months ago

Liar-zzy commented 10 months ago

Impressive work! I have run into some questions about custom operators while calculating the FLOPs (floating-point operations) for GFP-GAN. GFP-GAN uses custom operators from StyleGAN, such as fused_act, among others.

My Error logs:

Traceback (most recent call last):
  File "/data/basicFR/4gfp-gan/test_thop.py", line 94, in <module>
    main()
  File "/data/basicFR/4gfp-gan/test_thop.py", line 85, in main
    flops_counter = TorchFLOPsByFX(model)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch_flops/flops_engine.py", line 338, in __init__
    raise e
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch_flops/flops_engine.py", line 331, in __init__
    self.graph_model: GraphModule = symbolic_trace(model)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 857, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 566, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "/data/basicFR/4gfp-gan/gfpgan/archs/gfpganv1_arch.py", line 368, in forward
    for i in range(self.log_size - 2):
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 556, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 372, in call_module
    return forward(*args, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 552, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 556, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 372, in call_module
    return forward(*args, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 552, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/basicsr/ops/fused_act/fused_act.py", line 93, in forward
    return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/basicsr/ops/fused_act/fused_act.py", line 97, in fused_leaky_relu
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
  File "/home/amax/.conda/envs/gfp-bfr/lib/python3.10/site-packages/basicsr/ops/fused_act/fused_act.py", line 67, in forward
    out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
TypeError: fused_bias_act(): incompatible function arguments. The following argument types are supported:
    1. (arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: int, arg4: int, arg5: float, arg6: float) -> at::Tensor

Invoked with: Proxy(conv2d), Proxy(conv_body_first_1_bias), Proxy(new_empty), 3, 0, 0.2, 1.4142135623730951
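The "Invoked with: Proxy(...)" line is the key: during symbolic tracing, torch.fx passes Proxy placeholders instead of real tensors, and the compiled extension fused_act_ext.fused_bias_act only accepts at::Tensor arguments, hence the TypeError. A common workaround (a hedged sketch, not necessarily what torch_flops or basicsr prescribe) is to mark the Python wrapper function as a traced leaf with torch.fx.wrap, so the tracer records it as a single opaque node instead of tracing into it. Here fake_fused_leaky_relu is a hypothetical stand-in for basicsr's fused_leaky_relu:

```python
import torch
import torch.fx

# Hypothetical stand-in for basicsr's fused_leaky_relu: a free function whose
# real body calls a compiled CUDA extension that torch.fx cannot trace through.
def fake_fused_leaky_relu(x, negative_slope=0.2, scale=2 ** 0.5):
    return torch.nn.functional.leaky_relu(x, negative_slope) * scale

# Declare the function as a tracing leaf; torch.fx.wrap must be called at
# module (top) level. The tracer then emits one call_function node for it
# rather than tracing into its body.
torch.fx.wrap("fake_fused_leaky_relu")

class Tiny(torch.nn.Module):
    def forward(self, x):
        return fake_fused_leaky_relu(x)

gm = torch.fx.symbolic_trace(Tiny())
```

Note the trade-off: because the wrapped call becomes an unrecognized leaf node, a FLOPs counter built on the traced graph would likely report it as 'not recognized' rather than crash, so you may still need to account for its FLOPs manually.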

My Running Code

import torch
from torch_flops import TorchFLOPsByFX

# `model` (GFP-GAN) and `device` are assumed to be defined earlier.
x = torch.randn([1, 3, 512, 512]).to(device)
with torch.no_grad():
    # Build the graph of the model. You can specify the operations (listed in
    # `MODULE_FLOPs_MAPPING`, `FUNCTION_FLOPs_MAPPING` and `METHOD_FLOPs_MAPPING`
    # in 'flops_ops.py') to ignore.
    flops_counter = TorchFLOPsByFX(model)
    flops_counter.propagate(x)
# Print the FLOPs of each node in the graph. Note that if there are unsupported
# operations, the "flops" of those ops will be marked as 'not recognized'.
print('*' * 120)
flops_counter.print_result_table()
# Print the total FLOPs
total_flops = flops_counter.print_total_flops(show=True)

I will prioritize this issue during the coming week, as it is quite important to me. I would also be glad to contribute to this project.
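Another workaround worth sketching (an assumption on my part, not necessarily the fix the author ultimately used): since FusedLeakyReLU is an elementwise activation, you can swap those modules for plain torch.nn.LeakyReLU before handing the model to the tracer. This drops the learned bias and the sqrt(2) output scale, so it is only suitable for FLOPs counting, not for inference:

```python
import torch

def replace_fused_leaky_relu(model: torch.nn.Module) -> torch.nn.Module:
    """Recursively replace basicsr's FusedLeakyReLU modules (matched by class
    name, as seen in the traceback) with plain LeakyReLU so torch.fx can trace
    the model. Only valid for FLOPs counting: the learned bias and the
    sqrt(2) scale of the fused op are discarded."""
    for name, child in model.named_children():
        if type(child).__name__ == "FusedLeakyReLU":
            setattr(model, name, torch.nn.LeakyReLU(child.negative_slope))
        else:
            replace_fused_leaky_relu(child)
    return model
```

Usage would be `model = replace_fused_leaky_relu(model)` before constructing the FLOPs counter; tracing then sees only standard PyTorch modules.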

Liar-zzy commented 10 months ago

If you encounter any issues while executing GFP-GAN, please feel free to contact me via WeChat. Thanks a lot~

Liar-zzy commented 10 months ago

I have solved this problem.