Open · mohamad-hasan-sohan-ajini opened this issue 5 years ago
Hi,
My model has some LSTM layers, and `count_ops` throws the following error:
```
In [37]: count_ops(model, x)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-37-846862382dff> in <module>
----> 1 count_ops(model, x)

/usr/local/lib/python3.5/dist-packages/pthflops/ops.py in count_ops(model, input, custom_ops, ignore_layers, print_readable, verbose, *args)
    212     # Convert pytorch module to ONNX
    213     trace, _ = torch.jit.get_trace_graph(model, input, *args)
--> 214     torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    215     graph = trace.graph()
    216

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _optimize_trace(trace, operator_export_type)
     40 def _optimize_trace(trace, operator_export_type):
     41     from torch.onnx import utils
---> 42     trace.set_graph(utils._optimize_graph(trace.graph(), operator_export_type))
     43
     44

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type)
    153
    154     if operator_export_type != OperatorExportTypes.RAW:
--> 155         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    156     torch._C._jit_pass_lint(graph)
    157     torch._C._jit_pass_onnx_peephole(graph)

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     50 def _run_symbolic_function(*args, **kwargs):
     51     from torch.onnx import utils
---> 52     return utils._run_symbolic_function(*args, **kwargs)
     53
     54

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    502             return None
    503         fn = getattr(torch.onnx.symbolic, op_name)
--> 504         return fn(g, *inputs, **attrs)
    505
    506     elif ns == "prim":

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in lstm(g, *args)
   1274         return _lstm_packed(g, *args)
   1275     else:
-> 1276         return _lstm_full(g, *args)
   1277
   1278

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
     87             assert len(arg_descriptors) == len(args)
     88             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
---> 89             return fn(g, *args)
     90         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
     91         try:

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first)
   1260     hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
   1261     return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
-> 1262                         dropout, train, bidirectional, batch_first)
   1263
   1264

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, num_layers, dropout, train, bidirectional, batch_first, batch_sizes)
   1201             state_indices = i, i + 1
   1202         else:
-> 1203             weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
   1204             weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
   1205

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in transform_weights(layer_index)
   1188         elif variant == 'GRU' or variant == 'LSTM':
   1189             weight_ih, weight_hh, bias_ih, bias_hh = \
-> 1190                 [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
   1191             bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
   1192

ValueError: not enough values to unpack (expected 4, got 2)
```
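For context, the failing call looks roughly like the sketch below. The model class, tensor shapes, and the `bidirectional=True` / `bias=False` configuration are my assumptions rather than details from the issue; a bidirectional LSTM hits the `transform_weights(2 * i)` branch in the traceback, and `bias=False` is one setting that leaves only two weight tensors per layer, matching the "expected 4, got 2" unpack error in the ONNX symbolic on the older PyTorch shown here.

```python
# Minimal reproduction sketch (hypothetical model and shapes, not from the issue).
import torch
import torch.nn as nn
from pthflops import count_ops

class LSTMModel(nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False leaves 2 parameter tensors per layer/direction instead of 4,
        # which is consistent with the unpack failure in transform_weights above.
        self.lstm = nn.LSTM(input_size=40, hidden_size=128, num_layers=2,
                            bias=False, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * 128, 10)   # 2x hidden size for the two directions

    def forward(self, x):
        out, _ = self.lstm(x)              # out: (batch, time, 2 * hidden)
        return self.fc(out[:, -1])         # classify from the last time step

model = LSTMModel()
x = torch.randn(1, 100, 40)                # (batch, time, features)
count_ops(model, x)                        # raises the ValueError shown above
```

Re-running the same sketch with `bias=True` gives the symbolic the four tensors per layer it expects, which is one quick way to check whether missing biases are what breaks the export in a given model.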