adrianb / pytorch-estimate-flops

Estimate/count FLOPs for a given neural network using PyTorch
https://www.adrianbulat.com
BSD 3-Clause "New" or "Revised" License
305 stars 22 forks source link

LSTM support #2

Open mohamad-hasan-sohan-ajini opened 5 years ago

mohamad-hasan-sohan-ajini commented 5 years ago

Hi

My model has some LSTM layers, and `count_ops` throws the following error:

In [37]: count_ops(model, x)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-37-846862382dff> in <module>
----> 1 count_ops(model, x)

/usr/local/lib/python3.5/dist-packages/pthflops/ops.py in count_ops(model, input, custom_ops, ignore_layers, print_readable, verbose, *args)
    212     # Convert pytorch module to ONNX
    213     trace, _ = torch.jit.get_trace_graph(model, input, *args)
--> 214     torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    215     graph = trace.graph()
    216

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _optimize_trace(trace, operator_export_type)
     40 def _optimize_trace(trace, operator_export_type):
     41     from torch.onnx import utils
---> 42     trace.set_graph(utils._optimize_graph(trace.graph(), operator_export_type))
     43
     44

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type)
    153 
    154     if operator_export_type != OperatorExportTypes.RAW:
--> 155         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    156         torch._C._jit_pass_lint(graph)
    157         torch._C._jit_pass_onnx_peephole(graph)

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     50 def _run_symbolic_function(*args, **kwargs):
     51     from torch.onnx import utils
---> 52     return utils._run_symbolic_function(*args, **kwargs)
     53 
     54 

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    502                     return None
    503                 fn = getattr(torch.onnx.symbolic, op_name)
--> 504                 return fn(g, *inputs, **attrs)
    505 
    506         elif ns == "prim":

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in lstm(g, *args)
   1274         return _lstm_packed(g, *args)
   1275     else:
-> 1276         return _lstm_full(g, *args)
   1277
   1278

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
     87             assert len(arg_descriptors) == len(args)
     88             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
---> 89             return fn(g, *args)
     90         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
     91         try:

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first)
   1260     hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
   1261     return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
-> 1262                         dropout, train, bidirectional, batch_first)
   1263
   1264

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, num_layers, dropout, train, bidirectional, batch_first, batch_sizes)
   1201             state_indices = i, i + 1
   1202         else:
-> 1203             weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
   1204             weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
   1205

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in transform_weights(layer_index)
   1188         elif variant == 'GRU' or variant == 'LSTM':
   1189             weight_ih, weight_hh, bias_ih, bias_hh = \
-> 1190                 [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
   1191         bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
   1192

ValueError: not enough values to unpack (expected 4, got 2)