manicman1999 / StyleGAN2-Tensorflow-2.0

StyleGAN 2 in Tensorflow 2.0
MIT License

conv2d_mod/Conv2D NCHW not implemented #12

Open xiaoliangbai opened 3 years ago

xiaoliangbai commented 3 years ago
```
generated_images = self.GAN.GM.predict(n1 + [n2], batch_size = BATCH_SIZE)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 909, in predict
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 722, in predict
    callbacks=callbacks)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 393, in model_iteration
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py", line 3740, in call
    outputs = self._graph_fn(*converted_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1081, in call
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnimplementedError: The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
	 [[node model_1/conv2d_mod/Conv2D (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_11413]

Function call stack:
keras_scratch_graph
```

It seems the Conv2D op does not accept the NCHW data format. I tried forcing it to run on the GPU (`with tf.device('/gpu:1'): ...`), but that did not work. I also tried different TF versions (2.0, 2.3), including the Docker image for TF 2.0, and all of them hit the same issue.

Does anyone know how to get around this issue? Thanks.
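For context on the two layouts: NCHW stores a batch as `[batch, channels, height, width]` and NHWC as `[batch, height, width, channels]`. The data is identical and only the axis order differs. The minimal sketch below uses NumPy as a stand-in for TensorFlow (`ndarray.transpose` takes the same permutation lists as `tf.transpose`); the array and variable names are illustrative only. It shows that the permutation `[0, 2, 3, 1]` converts NCHW to NHWC and `[0, 3, 1, 2]` undoes it, which are the two permutations used in the workaround suggested in this thread.

```python
import numpy as np

# Illustrative batch in channels-first (NCHW) layout: N=2, C=3, H=4, W=5
x_nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# The two permutations (same lists one would pass to tf.transpose)
to_nhwc = (0, 2, 3, 1)  # NCHW -> NHWC
to_nchw = (0, 3, 1, 2)  # NHWC -> NCHW

x_nhwc = x_nchw.transpose(to_nhwc)
print(x_nhwc.shape)  # (2, 4, 5, 3)

# Same element, different index order: [n, c, h, w] becomes [n, h, w, c]
assert x_nchw[1, 2, 3, 4] == x_nhwc[1, 3, 4, 2]

# The second permutation is the inverse of the first
assert np.array_equal(x_nhwc.transpose(to_nchw), x_nchw)
```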

anthonyivol commented 3 years ago

This happens because it runs on the CPU. Try `batch_size = 1`, and in `conv_mod.py`:

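```python
# add this: move channels last (NCHW -> NHWC) before the convolution
x = tf.transpose(x, [0, 2, 3, 1])

# change data_format from "NCHW" to "NHWC"
x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NHWC")

# add this: move channels back to first position (NHWC -> NCHW)
x = tf.transpose(x, [0, 3, 1, 2])
```
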
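The thread does not explain why `batch_size = 1` is needed. One plausible reason, offered as an assumption and not confirmed here: StyleGAN2's modulated convolution is often implemented in a "fused" form that folds the minibatch into convolution groups. The input is reshaped to `[1, B*C, H, W]` and the per-sample filters to `[KH, KW, C, B*F]`. That is a grouped convolution, which the CPU Conv2D kernel may not support. With B = 1 there is only one group, so it reduces to an ordinary convolution. The NumPy sketch below uses hypothetical helper names (`conv_nchw`, `grouped_conv_nchw`, `fused_modconv`, none from the repo or TensorFlow). It checks that the fused form matches per-sample convolution, and that B = 1 gives a plain convolution.

```python
import numpy as np

def conv_nchw(x, w):
    """Naive stride-1 'SAME' conv (odd kernels). x: [N, C_in, H, W], w: [KH, KW, C_in, C_out]."""
    kh, kw, _, c_out = w.shape
    n, _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros((n, c_out, h, wd))
    for i in range(h):
        for j in range(wd):
            patch = xp[:, :, i:i + kh, j:j + kw]  # [N, C_in, KH, KW]
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [2, 0, 1]))
    return out

def grouped_conv_nchw(x, w, groups):
    """Grouped conv: split input and output channels into `groups` independent convs."""
    c_in_g = x.shape[1] // groups
    c_out_g = w.shape[3] // groups
    outs = [conv_nchw(x[:, g * c_in_g:(g + 1) * c_in_g],
                      w[..., g * c_out_g:(g + 1) * c_out_g])
            for g in range(groups)]
    return np.concatenate(outs, axis=1)

def fused_modconv(x, w_per_sample):
    """Assumed fused-modconv layout (as in official StyleGAN2), not verified against conv_mod.py.
    x: [B, C, H, W], w_per_sample: [B, KH, KW, C, F] (one modulated filter per sample)."""
    b, c, h, wd = x.shape
    _, kh, kw, _, f = w_per_sample.shape
    x_fused = x.reshape(1, b * c, h, wd)                # minibatch folded into channels
    w_fused = w_per_sample.transpose(1, 2, 3, 0, 4).reshape(kh, kw, c, b * f)
    y = grouped_conv_nchw(x_fused, w_fused, groups=b)   # [1, B*F, H, W]
    return y.reshape(b, f, h, wd)                       # groups unfolded back to minibatch

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 5, 5))
w = rng.standard_normal((2, 3, 3, 3, 4))

# Fused form equals convolving each sample with its own filter
per_sample = np.concatenate([conv_nchw(x[i:i + 1], w[i]) for i in range(2)], axis=0)
assert np.allclose(fused_modconv(x, w), per_sample)

# With B = 1 there is a single group: an ordinary convolution
assert np.allclose(fused_modconv(x[:1], w[:1]), conv_nchw(x[:1], w[0]))
```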
xiaoliangbai commented 3 years ago

Thanks Anthony, your solution works. I thought the weights would also need their in_channels axis transposed to match the activation data format, but it turns out they don't.
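This matches how `tf.nn.conv2d` is specified. Its filter is always shaped `[filter_height, filter_width, in_channels, out_channels]`, and `data_format` only describes the input and output tensors. The NumPy sketch below illustrates the point with hypothetical helpers (`conv2d_nhwc`, `conv2d_nchw_direct`, `conv2d_nchw_via_nhwc`, which are not part of TensorFlow or the repo). A naive stride-1 "SAME" convolution (cross-correlation, as in `tf.nn.conv2d`) on NCHW input gives the same result whether it is computed directly or through the transpose workaround, and both paths use the same untransposed filter. It assumes odd kernel sizes so the padding is symmetric.

```python
import numpy as np

def conv2d_nhwc(x, w):
    """Naive stride-1 'SAME' conv. x: [N, H, W, C_in], w: [KH, KW, C_in, C_out]."""
    kh, kw, _, c_out = w.shape
    n, h, wd, _ = x.shape
    xp = np.pad(x, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))
    out = np.zeros((n, h, wd, c_out))
    for i in range(h):
        for j in range(wd):
            patch = xp[:, i:i + kh, j:j + kw, :]  # [N, KH, KW, C_in]
            out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out

def conv2d_nchw_direct(x, w):
    """Reference conv written in channels-first indexing. x: [N, C_in, H, W], same w layout."""
    kh, kw, _, c_out = w.shape
    n, _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros((n, c_out, h, wd))
    for i in range(h):
        for j in range(wd):
            patch = xp[:, :, i:i + kh, j:j + kw]  # [N, C_in, KH, KW]
            # Contract C_in with w axis 2, KH with axis 0, KW with axis 1
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [2, 0, 1]))
    return out

def conv2d_nchw_via_nhwc(x_nchw, w):
    """The workaround from this thread: NCHW -> NHWC, convolve, NHWC -> NCHW.
    The filter w is passed through unchanged."""
    y = conv2d_nhwc(x_nchw.transpose(0, 2, 3, 1), w)
    return y.transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 6))  # NCHW activations
w = rng.standard_normal((3, 3, 3, 4))  # [KH, KW, C_in, C_out], never transposed
assert np.allclose(conv2d_nchw_via_nhwc(x, w), conv2d_nchw_direct(x, w))
```

Only the activations need permuting because the filter's axes name kernel positions and channel roles, not a memory layout of the image.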