mitkotak closed this 4 months ago
Sorry for the late reply. Could you provide more detailed error reports? I cannot locate the code causing this error.
Appreciate your input. Here's a notebook that reproduces the error with my setup: https://colab.research.google.com/drive/1ojTHd34sqEwS7r1YDl0fBF3jqNx1yd9z?usp=sharing
Sorry for not pasting the full trace earlier. The error comes from the shape propagation step:
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-19-122f84ff10dc> in <cell line: 6>()
4 from torch_flops import TorchFLOPsByFX
5 flops_counter = TorchFLOPsByFX(model)
----> 6 flops_counter.propagate(x)
/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in propagate(self, *args)
402
403 def propagate(self, *args):
--> 404 ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args)
405
406 result_table = []
/usr/local/lib/python3.10/dist-packages/torch_flops/flops_engine.py in __init__(self, gm, **kwargs)
154 self.ignore_ops = ignore_ops
155 self.mem_func_name = mem_func_name
--> 156 self.device = next(gm.parameters()).device
157
158 @compatibility(is_backward_compatible=True)
StopIteration:
It seems that `self.device = next(gm.parameters()).device` cannot be executed: `next()` raises `StopIteration` when `gm.parameters()` yields nothing, i.e. when the traced graph module has no registered parameters. This may be because your model is built on the e3nn lib. A possible solution is to avoid `next(gm.parameters())`: you may modify the source code at https://github.com/zugexiaodui/torch_flops/blob/dc72fb62934e107987fc3e9cb59d74d32b3910ef/torch_flops/flops_engine.py#L156 and specify `self.device` directly.
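A minimal sketch of that workaround, assuming you patch `ShapeProp.__init__` locally. `infer_device` is a hypothetical helper (not part of torch_flops or PyTorch). It uses the two-argument form `next(iterator, default)`, which returns the default instead of raising `StopIteration` when the module has no parameters, and it checks buffers before giving up.

```python
from typing import Any


def infer_device(module: Any, default: Any = "cpu") -> Any:
    """Hypothetical helper: device of the first parameter or buffer of `module`.

    Works on any torch.nn.Module (which provides .parameters() and .buffers()).
    Never raises StopIteration: next(iterator, None) returns None when the
    iterator is empty, and `default` is returned if nothing was found.
    """
    # Same source as flops_engine.py line 156, minus the StopIteration.
    first = next(module.parameters(), None)
    if first is None:
        # Fall back to registered buffers, if the module has any.
        first = next(module.buffers(), None)
    return first.device if first is not None else default
```

With this helper, line 156 could read `self.device = infer_device(gm, default=torch.device("cpu"))`; choose the default to match the device your inputs live on.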
Appreciate the input. I then ran into this error:
Traceback (most recent call last):
File "<ipython-input-19-ba1150c61b29>", line 304, in run_node
result, flops, exec_time = getattr(self, n.op)(n.target, args, kwargs)
File "<ipython-input-19-ba1150c61b29>", line 235, in call_function
flops = FUNCTION_FLOPs_MAPPING[func_name](result, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_flops/flops_ops.py", line 200, in FunctionFLOPs_elemwise
raise TypeError(type(result))
TypeError: <class 'torch.Size'>
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-19-ba1150c61b29> in run_node(self, n)
303 if n.op in ('call_module', 'call_function', 'call_method'):
--> 304 result, flops, exec_time = getattr(self, n.op)(n.target, args, kwargs)
305 else:
<ipython-input-19-ba1150c61b29> in call_function(self, target, args, kwargs)
234 if func_name not in self.ignore_ops:
--> 235 flops = FUNCTION_FLOPs_MAPPING[func_name](result, *args, **kwargs)
236 else:
/usr/local/lib/python3.10/dist-packages/torch_flops/flops_ops.py in FunctionFLOPs_elemwise(result, *args, **kwargs)
199 else:
--> 200 raise TypeError(type(result))
201
TypeError: <class 'torch.Size'>
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-20-e2b7355d83be> in <cell line: 5>()
3
4 flops_counter = TorchFLOPsByFX(model)
----> 5 flops_counter.propagate(x)
<ipython-input-19-ba1150c61b29> in propagate(self, *args)
402
403 def propagate(self, *args):
--> 404 ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args)
405
406 result_table = []
<ipython-input-19-ba1150c61b29> in propagate(self, *args)
369 else:
370 fake_args = args
--> 371 return super().run(*fake_args)
372
373
/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py in run(self, initial_env, enable_io_processing, *args)
143
144 try:
--> 145 self.env[node] = self.run_node(node)
146 except Exception as e:
147 if self.extra_traceback:
<ipython-input-19-ba1150c61b29> in run_node(self, n)
330 except Exception as e:
331 traceback.print_exc()
--> 332 raise RuntimeError(
333 f"ShapeProp error for: node={n.format_node()} with "
334 f"meta={n.meta}"
RuntimeError: ShapeProp error for: node=%add : [num_users=1] = call_function[target=operator.add](args = (%getitem, (4,)), kwargs = {}) with meta={'nn_module_stack': OrderedDict([('_compiled_main', ('_compiled_main', <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>))])}
While executing %add : [num_users=1] = call_function[target=operator.add](args = (%getitem, (4,)), kwargs = {})
Original traceback:
None
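For readers hitting the same `TypeError`: the failing node is `operator.add` applied to `(%getitem, (4,))`, and the counter received a `torch.Size` as the result. `torch.Size` is a subclass of `tuple`, so this node looks like shape arithmetic (tuple concatenation) rather than a tensor addition, which `FunctionFLOPs_elemwise` rejects. One way a counter could tolerate such nodes is sketched below; `elemwise_flops` is hypothetical and is not torch_flops' actual `FunctionFLOPs_elemwise`. It counts one FLOP per output element for tensor-like results and zero for shapes and Python scalars.

```python
from typing import Any


def elemwise_flops(result: Any) -> int:
    """Hypothetical elementwise FLOPs rule: one FLOP per output element.

    Not torch_flops' FunctionFLOPs_elemwise; a sketch of how non-tensor
    results could be tolerated instead of raising TypeError.
    """
    # torch.Size subclasses tuple, so shape concatenation such as
    # size + (4,) lands here: it is bookkeeping, not arithmetic.
    if isinstance(result, tuple):
        return 0
    # Python numbers from shape math (e.g. size[0] * 2) cost no tensor FLOPs.
    if isinstance(result, (int, float)):
        return 0
    # Tensor-like results: count one operation per element.
    if hasattr(result, "numel"):
        return int(result.numel())
    raise TypeError(type(result))
```

The tuple check comes first on purpose: `torch.Size` also has a `numel()` method, so a duck-typed check alone would count its product of dimensions as FLOPs.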
Updated the Colab notebook.
The error is raised by `model(x, x)`. Please check the input of the model. @mitkotak
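Following that advice, it can help to run one eager forward pass with exactly the arguments you will pass to `propagate` before counting FLOPs. A minimal sketch; `check_forward` is a hypothetical helper, not part of torch_flops:

```python
from typing import Any, Callable


def check_forward(model: Callable[..., Any], *inputs: Any) -> Any:
    """Hypothetical helper: run model(*inputs) once, eagerly.

    Call it with the same positional inputs you will later give to
    flops_counter.propagate(*inputs), so a wrong call signature fails here
    with the model's own traceback instead of deep inside shape propagation.
    """
    try:
        return model(*inputs)
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise RuntimeError(
            f"eager forward failed with {len(inputs)} input(s): {e}"
        ) from e
```

For example, call `check_forward(model, x)` first and only then `flops_counter.propagate(x)`.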
Gotcha, sorry about that. Yup, got it to work. Thanks for the patience!
Thanks for the great work!
Was wondering what this issue means