XifengGuo / CapsNet-Keras

A Keras implementation of CapsNet in NIPS2017 paper "Dynamic Routing Between Capsules". Now test error = 0.34%.
MIT License

Invalid shape on manipulate_latent [TF 2.2 Branch] #110

Open CoffeeStraw opened 4 years ago

CoffeeStraw commented 4 years ago

Hello, I was trying the new TF 2.2 branch. Training and testing seem to work just fine, but I've encountered an error in the manipulate_latent function.

I couldn't figure out a solution, since I've never seen the Keras predict function used this way, so I thought opening an issue would be wiser.

Here's the traceback:

------------------------------Begin: manipulate------------------------------
Traceback (most recent call last):
  File "capsulenet.py", line 261, in <module>
    manipulate_latent(manipulate_model, (x_test, y_test), args)
  File "capsulenet.py", line 184, in manipulate_latent
    x_recon = model.predict([x, y, tmp])
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\keras\engine\training.py", line 88, in _method_wrapper
    return method(self, *args, **kwargs)
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1268, in predict
    tmp_batch_outputs = predict_function(iterator)
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 650, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\function.py", line 1661, in _filtered_call
    return self._call_flat(
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\function.py", line 1745, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\function.py", line 593, in call
    outputs = execute.execute(
  File "D:\CapsNet-Keras\env\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 9216 values, but the requested shape requires a multiple of 800
  [[node model_2/primarycap_reshape/Reshape (defined at capsulenet.py:184) ]] [Op:__inference_predict_function_878]
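The two numbers in the error message can be checked with a little arithmetic. The sketch below is only one possible reading: the values 9216 and 800 come from the traceback, while the batch size of 100, the 8-dimensional primary capsules, and the 6x6x256 feature map per image are assumptions that the thread does not state. Under those assumptions, a single image yields exactly 9216 values, and a reshape tied to a fixed batch of 100 would need a multiple of 100 * 8 = 800.

```python
# Numbers taken directly from the error message.
values_in_tensor = 9216
required_multiple = 800

# ASSUMPTIONS (not stated in the thread): batch size 100, 8-D primary
# capsules, and a 6x6x256 convolution output per image.
assumed_batch_size = 100
assumed_dim_capsule = 8
assumed_features_per_image = 6 * 6 * 256

# Under these assumptions the tensor holds exactly one image ...
assert assumed_features_per_image == values_in_tensor
# ... while the reshape expects a full batch of 8-D capsules.
assert assumed_batch_size * assumed_dim_capsule == required_multiple


def reshape_is_valid(num_images):
    """Return True if num_images worth of features is a multiple of 800."""
    return (num_images * assumed_features_per_image) % required_multiple == 0


print(reshape_is_valid(1))    # a single image, as in manipulate_latent
print(reshape_is_valid(100))  # a full batch of the assumed size
```

If this reading is right, the reshape fails because predict receives one sample while the graph expects a whole batch.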

bklooste commented 4 years ago

I have a similar error and am creating an issue for it. Basically, the batch size is built into the model.
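If the batch size really is fixed in the graph, one possible workaround (a sketch, not a confirmed fix for this branch) is to pad the inputs up to a multiple of the batch size, run the prediction, and discard the padded rows. The helper name predict_with_fixed_batch, the stand-in fake_predict, the batch size of 100, and the input shapes below are all hypothetical and used only for illustration. With a real Keras model, predict_fn could be something like lambda arrays: model.predict(arrays, batch_size=100), assuming the model returns a single output array.

```python
import numpy as np


def predict_with_fixed_batch(predict_fn, inputs, batch_size):
    """Pad every input array to a multiple of batch_size, predict, trim padding.

    Hypothetical helper: predict_fn takes a list of arrays and returns one array.
    """
    n = inputs[0].shape[0]
    pad = (-n) % batch_size  # rows needed to reach the next multiple
    if pad:
        # Repeat the last row of each input to fill up the batch.
        inputs = [
            np.concatenate([a, np.repeat(a[-1:], pad, axis=0)], axis=0)
            for a in inputs
        ]
    outputs = predict_fn(inputs)
    return outputs[:n]  # keep only the rows for the real samples


def fake_predict(arrays):
    """Stand-in for a model whose graph only accepts full batches of 100."""
    x = arrays[0]
    if x.shape[0] % 100 != 0:
        raise ValueError("batch dimension must be a multiple of 100")
    return x.reshape(x.shape[0], -1).sum(axis=1)


# Illustrative single-sample inputs (shapes are assumptions).
x = np.ones((1, 28, 28, 1), dtype="float32")
y = np.zeros((1, 10), dtype="float32")
noise = np.zeros((1, 10, 16), dtype="float32")

result = predict_with_fixed_batch(fake_predict, [x, y, noise], batch_size=100)
print(result.shape)
```

Padding wastes some computation, but it leaves the model definition untouched. Rebuilding the model without a fixed batch dimension would be the cleaner route, if the layers in this branch allow it.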

Ysx2mina commented 3 years ago

same

Hallahallan commented 2 weeks ago

Same error here, did anyone find a solution to this?