I've added a batchnorm layer to your dummy example and it crashed.

The code:
from keras import backend as K
K.set_image_data_format('channels_first')

from pytorch2keras.converter import pytorch_to_keras
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable


class TestConv2d(nn.Module):
    def __init__(self, inp=10, out=16, kernel_size=3):
        super(TestConv2d, self).__init__()
        self.conv2d = nn.Conv2d(inp, out, stride=1, kernel_size=kernel_size, bias=True)
        self.bn_1 = nn.BatchNorm2d(num_features=out)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.bn_1(x)
        return x


model = TestConv2d()
input_np = np.random.uniform(0, 1, (1, 10, 32, 32))
input_var = Variable(torch.FloatTensor(input_np))

# we should specify the shape of the input tensor
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)
The error:
graph(%0 : Float(1, 10, 32, 32)
%1 : Float(16, 10, 3, 3)
%2 : Float(16)
%3 : Float(16)
%4 : Float(16)
%5 : Float(16)
%6 : Float(16)
%7 : Long()) {
%8 : Float(1, 16, 30, 30) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%0, %1, %2), scope: TestConv2d/Conv2d[conv2d]
%9 : Float(1, 16, 30, 30) = onnx::BatchNormalization[epsilon=1e-05, is_test=1, momentum=1](%8, %3, %4, %5, %6), scope: TestConv2d/BatchNorm2d[bn_1]
return (%9);
}
Graph inputs: ['0', '1', '2', '3', '4', '5', '6', '7']
Graph outputs: ['9']
State dict: ['conv2d.weight', 'conv2d.bias', 'bn_1.weight', 'bn_1.bias', 'bn_1.running_mean', 'bn_1.running_var', 'bn_1.num_batches_tracked']
____
graph node: TestConv2d/Conv2d[conv2d]
node id: 8
type: onnx::Conv
inputs: ['0', '1', '2']
outputs: ['TestConv2d/Conv2d[conv2d]']
name in state_dict: conv2d
attrs: {'dilations': [1, 1], 'group': 1, 'kernel_shape': [3, 3], 'pads': [0, 0, 0, 0], 'strides': [1, 1]}
is_terminal: False
Converting convolution ...
____
graph node: TestConv2d/BatchNorm2d[bn_1]
node id: 9
type: onnx::BatchNormalization
inputs: ['8', '3', '4', '5', '6']
outputs: ['TestConv2d/BatchNorm2d[bn_1]']
name in state_dict: bn_1
attrs: {'epsilon': 1e-05, 'is_test': 1, 'momentum': 1.0}
is_terminal: True
Converting batchnorm ...
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1575 try:
-> 1576 c_op = c_api.TF_FinishOperation(op_desc)
1577 except errors.InvalidArgumentError as e:
InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'bn_10.9351925092536912/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,16,1,1], [].
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-1-7f985261f348> in <module>()
24
25 # we should specify shape of the input tensor
---> 26 k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)
~\Anaconda3\lib\site-packages\pytorch2keras\converter.py in pytorch_to_keras(model, args, input_shapes, change_ordering, training, verbose, names)
313 node_input_names,
314 layers, state_dict,
--> 315 names
316 )
317 if node_id in graph_outputs:
~\Anaconda3\lib\site-packages\pytorch2keras\normalization_layers.py in convert_batchnorm(params, w_name, scope_name, inputs, layers, weights, names)
59 name=tf_name
60 )
---> 61 layers[scope_name] = bn(layers[inputs[0]])
62
63
~\Anaconda3\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
455 # Actually call the layer,
456 # collecting output(s), mask(s), and shape(s).
--> 457 output = self.call(inputs, **kwargs)
458 output_mask = self.compute_mask(inputs, previous_mask)
459
~\Anaconda3\lib\site-packages\keras\layers\normalization.py in call(self, inputs, training)
204 return K.in_train_phase(normed_training,
205 normalize_inference,
--> 206 training=training)
207
208 def get_config(self):
~\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py in in_train_phase(x, alt, training)
3121
3122 # else: assume learning phase is a placeholder tensor.
-> 3123 x = switch(training, x, alt)
3124 if uses_learning_phase:
3125 x._uses_learning_phase = True
~\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py in switch(condition, then_expression, else_expression)
3056 x = tf.cond(condition,
3057 then_expression_fn,
-> 3058 else_expression_fn)
3059 else:
3060 # tf.where needs its condition tensor
~\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
452 'in a future version' if date is None else ('after %s' % date),
453 instructions)
--> 454 return func(*args, **kwargs)
455 return tf_decorator.make_decorator(func, new_func, 'deprecated',
456 _add_deprecated_arg_notice_to_docstring(
~\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
2055 context_f = CondContext(pred, pivot_2, branch=0)
2056 context_f.Enter()
-> 2057 orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
2058 if orig_res_f is None:
2059 raise ValueError("false_fn must have a return value.")
~\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn)
1893 """Add the subgraph defined by fn() to the graph."""
1894 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
-> 1895 original_result = fn()
1896 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1897 if len(post_summaries) > len(pre_summaries):
~\Anaconda3\lib\site-packages\keras\layers\normalization.py in normalize_inference()
165 broadcast_gamma,
166 axis=self.axis,
--> 167 epsilon=self.epsilon)
168 else:
169 return K.batch_normalization(
~\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
1906 # so it may have extra axes with 1, it is not needed and should be removed
1907 if ndim(mean) > 1:
-> 1908 mean = tf.reshape(mean, (-1))
1909 if ndim(var) > 1:
1910 var = tf.reshape(var, (-1))
~\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_array_ops.py in reshape(tensor, shape, name)
7432 if _ctx is None or not _ctx._eager_context.is_eager:
7433 _, _, _op = _op_def_lib._apply_op_helper(
-> 7434 "Reshape", tensor=tensor, shape=shape, name=name)
7435 _result = _op.outputs[:]
7436 _inputs_flat = _op.inputs
~\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
785 op = g.create_op(op_type_name, inputs, output_types, name=scope,
786 input_types=input_types, attrs=attr_protos,
--> 787 op_def=op_def)
788 return output_structure, op_def.is_stateful, op
789
~\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
452 'in a future version' if date is None else ('after %s' % date),
453 instructions)
--> 454 return func(*args, **kwargs)
455 return tf_decorator.make_decorator(func, new_func, 'deprecated',
456 _add_deprecated_arg_notice_to_docstring(
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in create_op(***failed resolving arguments***)
3153 input_types=input_types,
3154 original_op=self._default_original_op,
-> 3155 op_def=op_def)
3156 self._create_op_helper(ret, compute_device=compute_device)
3157 return ret
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
1729 op_def, inputs, node_def.attr)
1730 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1731 control_input_ops)
1732
1733 # Initialize self._outputs.
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1577 except errors.InvalidArgumentError as e:
1578 # Convert to ValueError for backwards compatibility.
-> 1579 raise ValueError(str(e))
1580
1581 return c_op
ValueError: Shape must be rank 1 but is rank 0 for 'bn_10.9351925092536912/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,16,1,1], [].
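For what it's worth, the root cause is visible in the traceback itself: keras/backend/tensorflow_backend.py runs mean = tf.reshape(mean, (-1)), and in Python (-1) is just the integer -1, not the 1-tuple (-1,), so TensorFlow receives a rank-0 shape tensor and rejects reshaping the [1, 16, 1, 1] broadcast statistics ("Shape must be rank 1 but is rank 0"). Below is a minimal sketch of a local workaround, assuming the Keras 2.2.x backend shown in the traceback: it wraps K.batch_normalization and flattens the statistics to rank 1 before the buggy line can run. The wrapper name _patched_batch_normalization is mine, not part of either library:

from keras import backend as K
import tensorflow as tf

# Keep a reference to the original backend function.
_orig_batch_normalization = K.batch_normalization

def _patched_batch_normalization(x, mean, var, beta, gamma,
                                 axis=-1, epsilon=1e-3):
    # Target the fused NCHW inference path from the traceback: flatten the
    # broadcast statistics ([1, 16, 1, 1] -> [16]) before the backend runs,
    # so its `if ndim(mean) > 1: mean = tf.reshape(mean, (-1))` line (which
    # passes the rank-0 shape -1 instead of the tuple (-1,)) is never taken.
    if K.ndim(x) == 4 and axis in (1, -3):
        def _flatten(t):
            if t is not None and K.ndim(t) > 1:
                return tf.reshape(t, (-1,))
            return t
        mean, var = _flatten(mean), _flatten(var)
        beta, gamma = _flatten(beta), _flatten(gamma)
    return _orig_batch_normalization(x, mean, var, beta, gamma,
                                     axis=axis, epsilon=epsilon)

K.batch_normalization = _patched_batch_normalization

The patch has to be applied before calling pytorch_to_keras, since the layer looks up K.batch_normalization only when the Keras graph is built. Upgrading Keras may also resolve this; I believe the reshape was fixed upstream to use (-1,).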