zugexiaodui / torch_flops

A library for calculating the FLOPs in the forward() process based on torch.fx
MIT License

StopIteration Issue #15

Closed (mitkotak closed this issue 1 month ago)

mitkotak commented 1 month ago

Thanks for the great work!

I was wondering what this error means:

StopIteration                             Traceback (most recent call last)
<ipython-input-163-ff60ccfb7d72> in <cell line: 2>()
      1 flops_counter = TorchFLOPsByFX(model)
----> 2 flops_counter.propagate(x)

/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in propagate(self, *args)
    402 
    403     def propagate(self, *args):
--> 404         ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args)
    405 
    406         result_table = []
    407         for node in self.graph_model.graph.nodes:

/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in __init__(self, gm, **kwargs)
    154         self.ignore_ops = ignore_ops
    155         self.mem_func_name = mem_func_name
--> 156         self.device = next(gm.parameters()).device
    157 
    158     @compatibility(is_backward_compatible=True)
    159     def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:

StopIteration:
zugexiaodui commented 1 month ago

Sorry for the late reply. Could you provide a more detailed error report? I cannot locate the code causing this error.

mitkotak commented 1 month ago

Appreciate your input. Here's a notebook that reproduces the error with my setup: https://colab.research.google.com/drive/1ojTHd34sqEwS7r1YDl0fBF3jqNx1yd9z?usp=sharing


mitkotak commented 1 month ago

Sorry for not pasting the full trace earlier. The error comes from the shape propagation step:

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-19-122f84ff10dc> in <cell line: 6>()
      4 from torch_flops import TorchFLOPsByFX
      5 flops_counter = TorchFLOPsByFX(model)
----> 6 flops_counter.propagate(x)

/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in propagate(self, *args)
    402 
    403     def propagate(self, *args):
--> 404         ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args)
    405 
    406         result_table = []

/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in __init__(self, gm, **kwargs)
    154         self.ignore_ops = ignore_ops
    155         self.mem_func_name = mem_func_name
--> 156         self.device = next(gm.parameters()).device
    157 
    158     @compatibility(is_backward_compatible=True)

StopIteration:
zugexiaodui commented 1 month ago

It seems that self.device = next(gm.parameters()).device cannot be executed: next() raises StopIteration when gm.parameters() yields nothing, i.e. when the traced GraphModule has no parameters. This may be related to your model being built on the e3nn lib. A possible solution is to avoid calling next(gm.parameters()). You may modify the source code at https://github.com/zugexiaodui/torch_flops/blob/dc72fb62934e107987fc3e9cb59d74d32b3910ef/torch_flops/flops_engine.py#L156 and specify self.device directly.
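A minimal sketch of one way to pick a device without hitting StopIteration, assuming you patch line 156 yourself. The helper name infer_device and its buffer/CPU fallback order are hypothetical choices for illustration, not part of torch_flops. It relies on the two-argument form next(iterator, default), which returns the default instead of raising when the iterator is empty.

```python
import torch
import torch.nn as nn


def infer_device(module: nn.Module, default: str = "cpu") -> torch.device:
    """Hypothetical helper (not part of torch_flops): choose a device
    without raising StopIteration on parameter-free modules."""
    # next(iterator, None) returns None instead of raising StopIteration
    # when the module has no parameters.
    param = next(module.parameters(), None)
    if param is not None:
        return param.device
    # Some modules only register buffers (e.g. running statistics).
    buf = next(module.buffers(), None)
    if buf is not None:
        return buf.device
    # Nothing to inspect: fall back to an explicitly chosen device.
    return torch.device(default)


# A parameter-free module reproduces the original failure mode...
try:
    next(nn.ReLU().parameters())
except StopIteration:
    print("next() on an empty parameters() iterator raises StopIteration")

# ...while the helper falls back instead of raising.
print(infer_device(nn.ReLU()))
print(infer_device(nn.Linear(2, 2)))
```

With such a helper, line 156 could read self.device = infer_device(gm); hard-coding something like self.device = torch.device("cuda") is the simpler option the maintainer describes, at the cost of no longer following the model's actual placement.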

mitkotak commented 1 month ago

Appreciate the input. I then ran into this error:

Traceback (most recent call last):
  File "<ipython-input-19-ba1150c61b29>", line 304, in run_node
    result, flops, exec_time = getattr(self, n.op)(n.target, args, kwargs)
  File "<ipython-input-19-ba1150c61b29>", line 235, in call_function
    flops = FUNCTION_FLOPs_MAPPING[func_name](result, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_flops/flops_ops.py", line 200, in FunctionFLOPs_elemwise
    raise TypeError(type(result))
TypeError: <class 'torch.Size'>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-ba1150c61b29> in run_node(self, n)
    303                         if n.op in ('call_module', 'call_function', 'call_method'):
--> 304                             result, flops, exec_time = getattr(self, n.op)(n.target, args, kwargs)
    305                         else:

<ipython-input-19-ba1150c61b29> in call_function(self, target, args, kwargs)
    234             if func_name not in self.ignore_ops:
--> 235                 flops = FUNCTION_FLOPs_MAPPING[func_name](result, *args, **kwargs)
    236             else:

/usr/local/lib/python3.10/dist-packages/torch_flops/flops_ops.py in FunctionFLOPs_elemwise(result, *args, **kwargs)
    199     else:
--> 200         raise TypeError(type(result))
    201 

TypeError: <class 'torch.Size'>

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-20-e2b7355d83be> in <cell line: 5>()
      3 
      4 flops_counter = TorchFLOPsByFX(model)
----> 5 flops_counter.propagate(x)

<ipython-input-19-ba1150c61b29> in propagate(self, *args)
    402 
    403     def propagate(self, *args):
--> 404         ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args)
    405 
    406         result_table = []

<ipython-input-19-ba1150c61b29> in propagate(self, *args)
    369         else:
    370             fake_args = args
--> 371         return super().run(*fake_args)
    372 
    373 

/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py in run(self, initial_env, enable_io_processing, *args)
    143 
    144             try:
--> 145                 self.env[node] = self.run_node(node)
    146             except Exception as e:
    147                 if self.extra_traceback:

<ipython-input-19-ba1150c61b29> in run_node(self, n)
    330         except Exception as e:
    331             traceback.print_exc()
--> 332             raise RuntimeError(
    333                 f"ShapeProp error for: node={n.format_node()} with "
    334                 f"meta={n.meta}"

RuntimeError: ShapeProp error for: node=%add : [num_users=1] = call_function[target=operator.add](args = (%getitem, (4,)), kwargs = {}) with meta={'nn_module_stack': OrderedDict([('_compiled_main', ('_compiled_main', <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>))])}

While executing %add : [num_users=1] = call_function[target=operator.add](args = (%getitem, (4,)), kwargs = {})
Original traceback:
None

Updated the Colab notebook.
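For context on the failing node, %add = call_function[target=operator.add](args = (%getitem, (4,))): the sketch below shows one way torch.fx symbolic tracing can record such a node, namely slicing x.shape and concatenating a tuple. The module ReshapeLast is hypothetical and is not taken from the reporter's e3nn model. At run time that node evaluates to a shape (a torch.Size, which is a tuple subclass) rather than a tensor, which appears consistent with the TypeError: <class 'torch.Size'> raised by FunctionFLOPs_elemwise.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn


class ReshapeLast(nn.Module):
    """Hypothetical module used only to illustrate the node pattern."""

    def forward(self, x):
        # Shape arithmetic: x.shape[:-1] is a shape, "+ (4,)" concatenates a tuple.
        new_shape = x.shape[:-1] + (4,)
        return x.reshape(new_shape)


gm = fx.symbolic_trace(ReshapeLast())

# Collect the operator.add node recorded for the tuple concatenation.
add_nodes = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.add
]
print(add_nodes[0].format_node())

# Evaluated eagerly, the same expression yields a shape, not a tensor.
shape_result = torch.Size([2, 3]) + (4,)
print(type(shape_result), tuple(shape_result))

# The traced module still runs normally on a matching input.
out = gm(torch.randn(2, 3, 4))
print(out.shape)
```

A node like this does no arithmetic on tensor elements, so an elementwise FLOPs rule has no tensor to count; whether it should be skipped or counted as zero is a design decision for the library.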

zugexiaodui commented 1 month ago

The error is raised by model(x, x). Please check the inputs to the model. @mitkotak

mitkotak commented 1 month ago

Gotcha, sorry about that. Yup, got it to work. Thanks for your patience!