XifengGuo / CapsNet-Keras

A Keras implementation of CapsNet from the NIPS 2017 paper "Dynamic Routing Between Capsules". Current test error: 0.34%.
MIT License

GPU version can't work with keras. #45

Closed: clockzhong closed this issue 6 years ago

clockzhong commented 6 years ago

Hi, Mr. Guo, I've found a problem when using the GPU version. Here are my steps:

python capsulenet-multi-gpu.py --gpus 2

Using TensorFlow backend.
Namespace(batch_size=300, debug=0, digit=5, epochs=50, gpus=2, lam_recon=0.392, lr=0.001, routings=3, save_dir='./result', shift_fraction=0.1, testing=False, weights=None)


Layer (type)                    Output Shape          Param #     Connected to
==================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)     0
conv1 (Conv2D)                  (None, 20, 20, 256)   20992       input_1[0][0]
primarycap_conv2d (Conv2D)      (None, 6, 6, 256)     5308672     conv1[0][0]
primarycap_reshape (Reshape)    (None, 1152, 8)       0           primarycap_conv2d[0][0]
primarycap_squash (Lambda)      (None, 1152, 8)       0           primarycap_reshape[0][0]
digitcaps (CapsuleLayer)        (None, 10, 16)        1474560     primarycap_squash[0][0]
input_2 (InputLayer)            (None, 10)            0
mask_1 (Mask)                   (None, 160)           0           digitcaps[0][0]
                                                                  input_2[0][0]
capsnet (Length)                (None, 10)            0           digitcaps[0][0]
decoder (Sequential)            (None, 28, 28, 1)     1411344     mask_1[0][0]
==================================================================================
Total params: 8,215,568
Trainable params: 8,215,568
Non-trainable params: 0


Traceback (most recent call last):
  File "capsulenet-multi-gpu.py", line 122, in <module>
    plot_model(model, to_file=args.save_dir+'/model.png', show_shapes=True)
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/vis_utils.py", line 131, in plot_model
    dot = model_to_dot(model, show_shapes, show_layer_names, rankdir)
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/vis_utils.py", line 52, in model_to_dot
    _check_pydot()
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/vis_utils.py", line 27, in _check_pydot
    raise ImportError('Failed to import pydot. You must install pydot'
ImportError: Failed to import pydot. You must install pydot and graphviz for pydotprint to work.

I have already installed pydot and graphviz, and my Keras version is:

keras.__version__
'2.1.2'

Thanks!

Clock ZHONG
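For context on where this error comes from: in Keras 2.1.x, `_check_pydot` in `keras/utils/vis_utils.py` appears to raise this `ImportError` whenever rendering an empty pydot graph fails, not only when importing `pydot` fails. So having the Python packages installed is not enough: `pip install graphviz` installs only Python bindings, while pydot shells out to the Graphviz `dot` executable, which must be on `PATH`. The sketch below is a hypothetical diagnostic (`check_plot_model_deps` is not part of Keras), written for Python 3.3+ because of `shutil.which`. It tries the imports in the order Keras 2.1.x is assumed to use (`pydot_ng`, then `pydot`) and then repeats the empty-graph render.

```python
import importlib
import shutil


def check_plot_model_deps():
    """Report which prerequisites of Keras plot_model look usable (diagnostic sketch)."""
    report = {
        "pydot_ng": False,
        "pydot": False,
        # pydot needs the Graphviz `dot` binary; the pip "graphviz" package does not ship it.
        "dot_on_path": shutil.which("dot") is not None,
        "can_render": False,
    }
    module = None
    # Assumed Keras 2.1.x order: prefer pydot_ng, fall back to pydot.
    for name in ("pydot_ng", "pydot"):
        try:
            mod = importlib.import_module(name)
        except Exception:  # a broken install can fail with more than ImportError
            continue
        report[name] = True
        if module is None:
            module = mod
    if module is not None:
        try:
            # Render an empty graph, which is what the Keras check is assumed to do.
            module.Dot.create(module.Dot())
            report["can_render"] = True
        except Exception:
            pass
    return report


if __name__ == "__main__":
    for key, ok in check_plot_model_deps().items():
        print("%s: %s" % (key, "OK" if ok else "missing"))
```

If `pydot` shows as OK but `dot_on_path` and `can_render` do not, the system Graphviz package (not the pip one) is the likely missing piece.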

XifengGuo commented 6 years ago

@clockzhong Just comment out these two lines:

from keras.utils.vis_utils import plot_model

plot_model(model, to_file=args.save_dir+'/model.png', show_shapes=True)
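Commenting out the two lines works. A slightly less invasive variant, sketched here as an assumption rather than the repository's code, is to guard the call so that a missing pydot or Graphviz only produces a warning and training continues. `try_plot_model` and its `plot_fn` parameter are hypothetical names; the default import path is the one the script already uses.

```python
import warnings


def try_plot_model(model, to_file, plot_fn=None, **kwargs):
    """Draw the model diagram if pydot/Graphviz work; otherwise warn and continue.

    Returns True if the plot function ran, False if plotting was skipped.
    """
    if plot_fn is None:
        try:
            # Import path used by capsulenet-multi-gpu.py (Keras 2.1.x).
            from keras.utils.vis_utils import plot_model as plot_fn
        except ImportError as err:
            warnings.warn("Skipping model plot, plot_model unavailable: %s" % err)
            return False
    try:
        plot_fn(model, to_file=to_file, **kwargs)
    except ImportError as err:
        # Keras raises ImportError when pydot or Graphviz cannot be used.
        warnings.warn("Skipping model plot: %s" % err)
        return False
    return True
```

In the script, the call would become `try_plot_model(model, args.save_dir + '/model.png', show_shapes=True)`. The optional `plot_fn` argument exists only so the helper can be exercised without Keras installed.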

clockzhong commented 6 years ago

This is only a workaround. The real root cause is inside Keras's plot_model(), right?

XifengGuo commented 6 years ago

See #7 and #25. Some resources can be found in the Keras repository (https://github.com/keras-team/keras). I don't remember how I fixed this problem, but I believe you can find a solution by searching Google or Baidu.