EscVM / Efficient-CapsNet

Official TensorFlow code for the paper "Efficient-CapsNet: Capsule Network with Self-Attention Routing".
https://www.nature.com/articles/s41598-021-93977-0
Apache License 2.0

Visualisation notebook throws UnimplementedError: Fused conv implementation does not support grouped convolutions for now #1

Closed — msm1089 closed this issue 3 years ago

msm1089 commented 3 years ago

I am getting an error when running model_test.evaluate(mnist_dataset.X_test, mnist_dataset.y_test) in the dynamic visualisations notebook. The first time I installed the requirements there were a couple of warnings about incompatible versions; I installed again and everything was then OK.

The notebook runs fully on Colab, thanks, but I would like to get it working locally. Any ideas on how to fix this, please?


UnimplementedError                        Traceback (most recent call last)
<ipython-input> in <module>
----> 1 model_test.evaluate(mnist_dataset.X_test, mnist_dataset.y_test)

C:\DSAI\Dissertation\Efficient-CapsNet\models\model.py in evaluate(self, X_test, y_test)
     94             acc = np.mean(acc)
     95         else:
---> 96             y_pred, X_gen = self.model.predict(X_test)
     97             acc = np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0]
     98             test_error = 1 - acc

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\keras\engine\training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1627         for step in data_handler.steps():
   1628           callbacks.on_predict_batch_begin(step)
-> 1629           tmp_batch_outputs = self.predict_function(iterator)
   1630           if data_handler.should_sync:
   1631             context.async_wait()

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    892           *args, **kwds)
    893     # If we did not create any variables the trace we have is good enough.
--> 894     return self._concrete_stateful_fn._call_flat(
    895         filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
    896

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1916         and executing_eagerly):
   1917       # No tape is watching; skip to running the function.
-> 1918       return self._build_call_outputs(self._inference_function.call(
   1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    553     with _InterpolateFunctionError(self):
    554       if cancellation_manager is None:
--> 555         outputs = execute.execute(
    556             str(self.signature.name),
    557             num_outputs=self._num_outputs,

~\miniconda3\envs\dsai_py3.8\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

UnimplementedError: Fused conv implementation does not support grouped convolutions for now.
	 [[node Efficinet_CapsNet_Generator/Efficient_CapsNet/primary_caps_2/conv2d/BiasAdd (defined at C:\DSAI\Dissertation\Efficient-CapsNet\utils\layers.py:129) ]] [Op:__inference_predict_function_3999]

Function call stack:
predict_function
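For context, the op that fails here is a grouped convolution in the PrimaryCaps layer (the `conv2d` node in `utils/layers.py`). A grouped convolution splits the input channels into independent groups, each convolved with its own filter bank, which is why it needs a dedicated kernel rather than the fused CPU path. A minimal NumPy sketch of the idea, using a hypothetical 1x1 helper that is not part of the repository:

```python
import numpy as np

def grouped_conv1x1(x, weights, groups):
    """1x1 grouped convolution sketch (illustration only, not the repo's code).

    x:       input of shape (H, W, C_in)
    weights: list of `groups` matrices, each (C_in // groups, C_out_per_group)
    Each group of input channels is mixed only with its own filter bank.
    """
    h, w, c_in = x.shape
    assert c_in % groups == 0, "input channels must divide evenly into groups"
    cg = c_in // groups
    outs = []
    for g in range(groups):
        xg = x[:, :, g * cg:(g + 1) * cg]   # this group's slice of channels
        outs.append(xg @ weights[g])        # per-group channel mixing
    # groups never interact; their outputs are simply concatenated
    return np.concatenate(outs, axis=-1)
```

With `groups` equal to the number of input channels this degenerates to a depthwise convolution; with `groups=1` it is an ordinary convolution.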
msm1089 commented 3 years ago

It turned out to be due to TensorFlow not finding my GPU. I fixed that and the notebook now runs fine :)

EscVM commented 3 years ago

Yes, the current version of TensorFlow doesn't have a CPU kernel for that specific operation. We should add a hint to use the GPU runtime.
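Such a hint could be a fail-fast check at the top of the notebooks. A sketch of a hypothetical helper (not part of the repository) that takes the device list returned by `tf.config.list_physical_devices('GPU')` and raises a clear error instead of the opaque UnimplementedError:

```python
def require_gpu(visible_gpus):
    """Fail fast with a readable message when no GPU is visible.

    `visible_gpus` is expected to be the list returned by
    tf.config.list_physical_devices('GPU'). Hypothetical helper,
    not part of the Efficient-CapsNet repository.
    """
    if not visible_gpus:
        raise RuntimeError(
            "No GPU found. The grouped convolution in PrimaryCaps has no "
            "fused CPU kernel in this TensorFlow version, so predict() "
            "fails with UnimplementedError on CPU. Use a GPU runtime."
        )
    return visible_gpus
```

In a notebook this would be called right after the imports, e.g. `require_gpu(tf.config.list_physical_devices('GPU'))`.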

Thanks!