keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

keras model add layer created by k.stack in order to create stacked output #11940

Closed thebeancounter closed 5 years ago

thebeancounter commented 5 years ago

I am using Keras with the TensorFlow backend to build a siamese network. I am trying to write a custom loss function for triplet loss, so I need to pass it multiple outputs packed into one tensor, which I can then split inside the loss function so that the gradient can be computed.
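For readers unfamiliar with the pattern, the "pack several embeddings into one tensor, then split inside the loss" idea can be sketched in plain NumPy. This is only an illustration, not the loss used in this issue: the function name `triplet_loss_from_stacked`, the layout (anchor, positive, negative along the last axis) and the margin of 1.0 are all assumptions made for the example.

```python
import numpy as np

def triplet_loss_from_stacked(y_pred, margin=1.0):
    """Triplet loss on a stacked array of shape (batch, dim, 3).

    Assumed (hypothetical) layout along the last axis:
    index 0 = anchor, index 1 = positive, index 2 = negative.
    """
    # Split the single stacked tensor back into its three embeddings.
    anchor, positive, negative = y_pred[..., 0], y_pred[..., 1], y_pred[..., 2]
    pos_dist = np.sum((anchor - positive) ** 2, axis=-1)  # squared distance per sample
    neg_dist = np.sum((anchor - negative) ** 2, axis=-1)
    # Hinge: penalize only when the negative is not at least `margin` farther away.
    return np.mean(np.maximum(pos_dist - neg_dist + margin, 0.0))

# Toy batch: 2 samples with 4-dimensional embeddings.
anchor = np.zeros((2, 4))
positive = np.zeros((2, 4))      # identical to the anchor, distance 0
negative = np.full((2, 4), 2.0)  # squared distance 4 * 2**2 = 16
stacked = np.stack([anchor, positive, negative], axis=-1)  # shape (2, 4, 3)

easy_loss = triplet_loss_from_stacked(stacked)                  # 0 - 16 + 1 < 0, clipped to 0
hard_loss = triplet_loss_from_stacked(stacked[..., [0, 2, 1]])  # roles swapped: 16 - 0 + 1 = 17
```

In a Keras loss the same slicing would be applied to `y_pred` with backend ops instead of NumPy, and `y_true` would typically be ignored.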

I am trying to follow what is explained here about using multiple outputs in concatenated form, and I use it in my code as follows:

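from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dense, Input, Lambda, concatenate
from keras.models import Model
from keras.optimizers import SGD

input_layer = Input(shape=(784,))
a = Dense(100, activation="relu")(input_layer)
o = Dense(40, activation="relu")(a)
# Add a trailing axis so the two copies can be joined along it: (?, 40, 1) each
layer1 = Lambda(lambda x: K.expand_dims(x, axis=2))(o)
layer2 = Lambda(lambda x: K.expand_dims(x, axis=2))(o)
concat_layer = concatenate([layer1, layer2], axis=2)  # (?, 40, 2)

model = Model(input_layer, concat_layer)
# triplet_loss_wrapper() returns the custom loss; its definition is not shown
model.compile(optimizer=SGD(), loss=triplet_loss_wrapper())

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test = x_test.reshape(x_test.shape[0], 784)

# Dummy labels: one scalar per sample
model.fit(x_test, [1] * len(x_test))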

I get the following error:

(np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (32, 1) for Tensor 'concatenate_1_target:0', which has shape '(?, ?, ?)'
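The message says that a rank-2 batch of labels, shape (32, 1), is being fed into a rank-3 target placeholder. A possible workaround, independent of library versions, is to feed dummy targets with the same rank as the model output. The sketch below only illustrates the shapes in NumPy; it assumes the custom loss never reads `y_true`, and `rank_matches` is a hypothetical helper, not a Keras or TensorFlow function.

```python
import numpy as np

# Sizes taken from the model above: the output has shape (?, 40, 2).
batch_size, embedding_dim, n_stacked = 32, 40, 2

# What the fit call above effectively feeds: one scalar label per sample.
flat_targets = np.ones((batch_size, 1))

# Dummy targets with the same rank as the model output. The values are
# placeholders; this assumes the custom loss ignores y_true entirely.
dummy_targets = np.zeros((batch_size, embedding_dim, n_stacked))

def rank_matches(value, placeholder_rank):
    """Return True if the array has as many axes as the target placeholder."""
    return value.ndim == placeholder_rank

flat_ok = rank_matches(flat_targets, 3)    # False: rank 2 vs rank 3
dummy_ok = rank_matches(dummy_targets, 3)  # True
```

With the real model, this would correspond to passing something like `np.zeros((len(x_test), 40, 2))` as the second argument to `model.fit`.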

Please also see my related Stack Overflow question.

thebeancounter commented 5 years ago

Update: this code runs on Google Colab but raises the error above when I run it locally in a conda env. Here is the list of packages in the conda env:

# packages in environment at /home/shai/anaconda3/envs/siamese:
#
_license 1.1 py36_1
alabaster 0.7.10 py36_0
anaconda custom py36_0
anaconda-client 1.6.3 py36_0
anaconda-navigator 1.6.4 py36_0
anaconda-project 0.6.0 py36_0
asn1crypto 0.22.0 py36_0
astroid 1.5.3 py36_0
astropy 2.0.1 np113py36_0
babel 2.5.0 py36_0
backports 1.0 py36_0
backports.weakref 1.0rc1 py36_0
beautifulsoup4 4.6.0 py36_0
bitarray 0.8.1 py36_0
bkcharts 0.2 py36_0
blas 1.0 mkl
blaze 0.10.1 py36_0
bleach 1.5.0 py36_0
bokeh 0.12.7 py36_0
boto 2.48.0 py36_0
bottleneck 1.2.1 np113py36_0
cairo 1.14.8 0
certifi 2016.2.28 py36_0
cffi 1.10.0 py36_0
chardet 3.0.4 py36_0
click 6.7 py36_0
cloudpickle 0.4.0 py36_0
clyent 1.2.2 py36_0
colorama 0.3.9 py36_0
contextlib2 0.5.5 py36_0
cryptography 1.8.1 py36_0
cudatoolkit 8.0 3
cudnn 6.0.21 cuda8.0_0
curl 7.52.1 0
cycler 0.10.0 py36_0
cython 0.26 py36_0
cytoolz 0.8.2 py36_0
dask 0.15.2 py36_0
datashape 0.5.4 py36_0
dbus 1.10.20 0
decorator 4.1.2 py36_0
distributed 1.18.1 py36_0
docutils 0.14 py36_0
entrypoints 0.2.3 py36_0
et_xmlfile 1.0.1 py36_0
expat 2.1.0 0
fastcache 1.0.2 py36_1
flask 0.12.2 py36_0
flask-cors 3.0.3 py36_0
fontconfig 2.12.1 3
freetype 2.5.5 2
get_terminal_size 1.0.0 py36_0
gevent 1.2.2 py36_0
glib 2.50.2 1
greenlet 0.4.12 py36_0
gst-plugins-base 1.8.0 0
gstreamer 1.8.0 0
h5py 2.7.0 np113py36_0
harfbuzz 0.9.39 2
hdf5 1.8.17 2
heapdict 1.0.0 py36_1
html5lib 0.9999999 py36_0
icu 54.1 0
idna 2.6 py36_0
imagesize 0.7.1 py36_0
ipykernel 4.6.1 py36_0
ipython 6.1.0 py36_0
ipython_genutils 0.2.0 py36_0
ipywidgets 6.0.0 py36_0
isort 4.2.15 py36_0
itsdangerous 0.24 py36_0
jbig 2.1 0
jdcal 1.3 py36_0
jedi 0.10.2 py36_2
jinja2 2.9.6 py36_0
jpeg 9b 0
jsonschema 2.6.0 py36_0
jupyter 1.0.0 py36_3
jupyter_client 5.1.0 py36_0
jupyter_console 5.2.0 py36_0
jupyter_core 4.3.0 py36_0
keras 2.0.5 py36_0
lazy-object-proxy 1.3.1 py36_0
libffi 3.2.1 1
libgcc 5.2.0 0
libgfortran 3.0.0 1
libgpuarray 0.6.9 0
libiconv 1.14 0
libpng 1.6.30 1
libprotobuf 3.4.0 0
libsodium 1.0.10 0
libtiff 4.0.6 3
libtool 2.4.2 0
libxcb 1.12 1
libxml2 2.9.4 0
libxslt 1.1.29 0
llvmlite 0.20.0 py36_0
locket 0.2.0 py36_1
lxml 3.8.0 py36_0
mako 1.0.6 py36_0
markdown 2.6.9 py36_0
markupsafe 1.0 py36_0
matplotlib 2.0.2 np113py36_0
mistune 0.7.4 py36_0
mkl 2017.0.3 0
mkl-service 1.1.2 py36_3
mpmath 0.19 py36_1
msgpack-python 0.4.8 py36_0
multipledispatch 0.4.9 py36_0
navigator-updater 0.1.0 py36_0
nbconvert 5.2.1 py36_0
nbformat 4.4.0 py36_0
nccl 1.3.4 cuda8.0_1
networkx 1.11 py36_0
nltk 3.2.4 py36_0
nose 1.3.7 py36_1
notebook 5.0.0 py36_0
numba 0.35.0 np113py36_0
numexpr 2.6.2 np113py36_0
numpy 1.13.1 py36_0
numpydoc 0.7.0 py36_0
odo 0.5.1 py36_0
olefile 0.44 py36_0
openpyxl 2.4.8 py36_0
openssl 1.0.2l 0
packaging 16.8 py36_0
pandas 0.20.3 py36_0
pandocfilters 1.4.2 py36_0
pango 1.40.3 1
partd 0.3.8 py36_0
path.py 10.3.1 py36_0
pathlib2 2.3.0 py36_0
patsy 0.4.1 py36_0
pcre 8.39 1
pep8 1.7.0 py36_0
pexpect 4.2.1 py36_0
pickleshare 0.7.4 py36_0
pillow 4.2.1 py36_0
pip 9.0.1 py36_1
pixman 0.34.0 0
ply 3.10 py36_0
prompt_toolkit 1.0.15 py36_0
protobuf 3.4.0 py36_0
psutil 5.2.2 py36_0
ptyprocess 0.5.2 py36_0
py 1.4.33 py36_0
pycodestyle 2.3.1 py36_0
pycosat 0.6.2 py36_0
pycparser 2.18 py36_0
pycrypto 2.6.1 py36_6
pycurl 7.43.0 py36_2
pyflakes 1.6.0 py36_0
pygments 2.2.0 py36_0
pygpu 0.6.9 py36_0
pylint 1.7.2 py36_0
pyodbc 4.0.17 py36_0
pyopenssl 17.0.0 py36_0
pyparsing 2.2.0 py36_0
pyqt 5.6.0 py36_2
pytables 3.4.2 np113py36_0
pytest 3.2.1 py36_0
python 3.6.2 0
python-dateutil 2.6.1 py36_0
pytorch 0.1.12 py36cuda8.0cudnn6.0_1
pytz 2017.2 py36_0
pywavelets 0.5.2 np113py36_0
pyyaml 3.12 py36_0
pyzmq 16.0.2 py36_0
qt 5.6.2 4
qtawesome 0.4.4 py36_0
qtconsole 4.3.1 py36_0
qtpy 1.3.1 py36_0
readline 6.2 2
requests 2.14.2 py36_0
rope 0.9.4 py36_1
ruamel_yaml 0.11.14 py36_1
scikit-image 0.13.0 np113py36_0
scikit-learn 0.19.0 np113py36_0
scipy 0.19.1 np113py36_0
seaborn 0.8 py36_0
setuptools 36.4.0 py36_1
simplegeneric 0.8.1 py36_1
singledispatch 3.4.0.3 py36_0
sip 4.18 py36_0
six 1.10.0 py36_0
snowballstemmer 1.2.1 py36_0
sortedcollections 0.5.3 py36_0
sortedcontainers 1.5.7 py36_0
sphinx 1.6.3 py36_0
sphinxcontrib 1.0 py36_0
sphinxcontrib-websupport 1.0.1 py36_0
spyder 3.2.3 py36_0
sqlalchemy 1.1.13 py36_0
sqlite 3.13.0 0
statsmodels 0.8.0 np113py36_0
sympy 1.1.1 py36_0
tblib 1.3.2 py36_0
tensorflow 1.3.0 0
tensorflow-base 1.3.0 py36h5293eaa_1
tensorflow-tensorboard 0.1.5 py36_0
terminado 0.6 py36_0
testpath 0.3.1 py36_0
theano 0.9.0 py36_0
tk 8.5.18 0
toolz 0.8.2 py36_0
torchvision 0.1.8 py36_0
tornado 4.5.2 py36_0
traitlets 4.3.2 py36_0
unicodecsv 0.14.1 py36_0
unixodbc 2.3.4 0
wcwidth 0.1.7 py36_0
werkzeug 0.12.2 py36_0
wheel 0.29.0 py36_0
widgetsnbextension 3.0.2 py36_0
wrapt 1.10.11 py36_0
xlrd 1.1.0 py36_0
xlsxwriter 0.9.8 py36_0
xlwt 1.3.0 py36_0
xz 5.2.3 0
yaml 0.1.6 0
zeromq 4.1.5 0
zict 0.1.2 py36_0
zlib 1.2.11 0

thebeancounter commented 5 years ago

Update: I just upgraded to Keras 2.2.4 and still get the same error.

Here is my updated list of packages:

_license 1.1 py36_1
alabaster 0.7.10 py36_0
anaconda custom py36_0
anaconda-client 1.6.3 py36_0
anaconda-navigator 1.6.4 py36_0
anaconda-project 0.6.0 py36_0
asn1crypto 0.22.0 py36_0
astroid 1.5.3 py36_0
astropy 2.0.1 np113py36_0
babel 2.5.0 py36_0
backports 1.0 py36_0
backports.weakref 1.0rc1 py36_0
beautifulsoup4 4.6.0 py36_0
bitarray 0.8.1 py36_0
bkcharts 0.2 py36_0
blas 1.0 mkl
blaze 0.10.1 py36_0
bleach 1.5.0 py36_0
bokeh 0.12.7 py36_0
boto 2.48.0 py36_0
bottleneck 1.2.1 np113py36_0
cairo 1.14.8 0
certifi 2016.2.28 py36_0
cffi 1.10.0 py36_0
chardet 3.0.4 py36_0
click 6.7 py36_0
cloudpickle 0.4.0 py36_0
clyent 1.2.2 py36_0
colorama 0.3.9 py36_0
contextlib2 0.5.5 py36_0
cryptography 1.8.1 py36_0
cudatoolkit 8.0 3
cudnn 6.0.21 cuda8.0_0
curl 7.52.1 0
cycler 0.10.0 py36_0
cython 0.26 py36_0
cytoolz 0.8.2 py36_0
dask 0.15.2 py36_0
datashape 0.5.4 py36_0
dbus 1.10.20 0
decorator 4.1.2 py36_0
distributed 1.18.1 py36_0
docutils 0.14 py36_0
entrypoints 0.2.3 py36_0
et_xmlfile 1.0.1 py36_0
expat 2.1.0 0
fastcache 1.0.2 py36_1
flask 0.12.2 py36_0
flask-cors 3.0.3 py36_0
fontconfig 2.12.1 3
freetype 2.5.5 2
get_terminal_size 1.0.0 py36_0
gevent 1.2.2 py36_0
glib 2.50.2 1
greenlet 0.4.12 py36_0
gst-plugins-base 1.8.0 0
gstreamer 1.8.0 0
h5py 2.7.0 np113py36_0
harfbuzz 0.9.39 2
hdf5 1.8.17 2
heapdict 1.0.0 py36_1
html5lib 0.9999999 py36_0
icu 54.1 0
idna 2.6 py36_0
imagesize 0.7.1 py36_0
ipykernel 4.6.1 py36_0
ipython 6.1.0 py36_0
ipython_genutils 0.2.0 py36_0
ipywidgets 6.0.0 py36_0
isort 4.2.15 py36_0
itsdangerous 0.24 py36_0
jbig 2.1 0
jdcal 1.3 py36_0
jedi 0.10.2 py36_2
jinja2 2.9.6 py36_0
jpeg 9b 0
jsonschema 2.6.0 py36_0
jupyter 1.0.0 py36_3
jupyter_client 5.1.0 py36_0
jupyter_console 5.2.0 py36_0
jupyter_core 4.3.0 py36_0
keras 2.2.4 py36_0 conda-forge
keras-applications 1.0.4 py_1 conda-forge
keras-preprocessing 1.0.2 py_1 conda-forge
lazy-object-proxy 1.3.1 py36_0
libffi 3.2.1 1
libgcc 5.2.0 0
libgfortran 3.0.0 1
libgpuarray 0.6.9 0
libiconv 1.14 0
libpng 1.6.30 1
libprotobuf 3.4.0 0
libsodium 1.0.10 0
libtiff 4.0.6 3
libtool 2.4.2 0
libxcb 1.12 1
libxml2 2.9.4 0
libxslt 1.1.29 0
llvmlite 0.20.0 py36_0
locket 0.2.0 py36_1
lxml 3.8.0 py36_0
mako 1.0.6 py36_0
markdown 2.6.9 py36_0
markupsafe 1.0 py36_0
matplotlib 2.0.2 np113py36_0
mistune 0.7.4 py36_0
mkl 2017.0.3 0
mkl-service 1.1.2 py36_3
mpmath 0.19 py36_1
msgpack-python 0.4.8 py36_0
multipledispatch 0.4.9 py36_0
navigator-updater 0.1.0 py36_0
nbconvert 5.2.1 py36_0
nbformat 4.4.0 py36_0
nccl 1.3.4 cuda8.0_1
networkx 1.11 py36_0
nltk 3.2.4 py36_0
nose 1.3.7 py36_1
notebook 5.0.0 py36_0
numba 0.35.0 np113py36_0
numexpr 2.6.2 np113py36_0
numpy 1.13.1 py36_0
numpydoc 0.7.0 py36_0
odo 0.5.1 py36_0
olefile 0.44 py36_0
openpyxl 2.4.8 py36_0
openssl 1.0.2l 0
packaging 16.8 py36_0
pandas 0.20.3 py36_0
pandocfilters 1.4.2 py36_0
pango 1.40.3 1
partd 0.3.8 py36_0
path.py 10.3.1 py36_0
pathlib2 2.3.0 py36_0
patsy 0.4.1 py36_0
pcre 8.39 1
pep8 1.7.0 py36_0
pexpect 4.2.1 py36_0
pickleshare 0.7.4 py36_0
pillow 4.2.1 py36_0
pip 9.0.1 py36_1
pixman 0.34.0 0
ply 3.10 py36_0
prompt_toolkit 1.0.15 py36_0
protobuf 3.4.0 py36_0
psutil 5.2.2 py36_0
ptyprocess 0.5.2 py36_0
py 1.4.33 py36_0
pycodestyle 2.3.1 py36_0
pycosat 0.6.2 py36_0
pycparser 2.18 py36_0
pycrypto 2.6.1 py36_6
pycurl 7.43.0 py36_2
pyflakes 1.6.0 py36_0
pygments 2.2.0 py36_0
pygpu 0.6.9 py36_0
pylint 1.7.2 py36_0
pyodbc 4.0.17 py36_0
pyopenssl 17.0.0 py36_0
pyparsing 2.2.0 py36_0
pyqt 5.6.0 py36_2
pytables 3.4.2 np113py36_0
pytest 3.2.1 py36_0
python 3.6.2 0
python-dateutil 2.6.1 py36_0
pytorch 0.1.12 py36cuda8.0cudnn6.0_1
pytz 2017.2 py36_0
pywavelets 0.5.2 np113py36_0
pyyaml 3.12 py36_0
pyzmq 16.0.2 py36_0
qt 5.6.2 4
qtawesome 0.4.4 py36_0
qtconsole 4.3.1 py36_0
qtpy 1.3.1 py36_0
readline 6.2 2
requests 2.14.2 py36_0
rope 0.9.4 py36_1
ruamel_yaml 0.11.14 py36_1
scikit-image 0.13.0 np113py36_0
scikit-learn 0.19.0 np113py36_0
scipy 0.19.1 np113py36_0
seaborn 0.8 py36_0
setuptools 36.4.0 py36_1
simplegeneric 0.8.1 py36_1
singledispatch 3.4.0.3 py36_0
sip 4.18 py36_0
six 1.10.0 py36_0
snowballstemmer 1.2.1 py36_0
sortedcollections 0.5.3 py36_0
sortedcontainers 1.5.7 py36_0
sphinx 1.6.3 py36_0
sphinxcontrib 1.0 py36_0
sphinxcontrib-websupport 1.0.1 py36_0
spyder 3.2.3 py36_0
sqlalchemy 1.1.13 py36_0
sqlite 3.13.0 0
statsmodels 0.8.0 np113py36_0
sympy 1.1.1 py36_0
tblib 1.3.2 py36_0
tensorflow 1.3.0 0
tensorflow-base 1.3.0 py36h5293eaa_1
tensorflow-tensorboard 0.1.5 py36_0
terminado 0.6 py36_0
testpath 0.3.1 py36_0
theano 0.9.0 py36_0
tk 8.5.18 0
toolz 0.8.2 py36_0
torchvision 0.1.8 py36_0
tornado 4.5.2 py36_0
traitlets 4.3.2 py36_0
unicodecsv 0.14.1 py36_0
unixodbc 2.3.4 0
wcwidth 0.1.7 py36_0
werkzeug 0.12.2 py36_0
wheel 0.29.0 py36_0
widgetsnbextension 3.0.2 py36_0
wrapt 1.10.11 py36_0
xlrd 1.1.0 py36_0
xlsxwriter 0.9.8 py36_0
xlwt 1.3.0 py36_0
xz 5.2.3 0
yaml 0.1.6 0
zeromq 4.1.5 0
zict 0.1.2 py36_0
zlib 1.2.11 0

thebeancounter commented 5 years ago

Update: I verified that Keras on Google Colab is the same version as my local one, so upgrading to 2.2.4 did not solve the issue.

I checked further and found that my local TensorFlow was 1.3. Updating it to 1.10 using conda install -c conda-forge tensorflow

solved the issue.
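For anyone comparing two environments the same way, a short script can report the relevant versions without reading the full package list. This is a minimal sketch that assumes Python 3.8 or newer (for `importlib.metadata`, which the Python 3.6 env above does not have) and assumes the packages were installed with standard metadata; `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata  # standard library in Python 3.8+

def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Run this in both environments (e.g. Colab and the local conda env) and compare.
report = installed_versions(["keras", "tensorflow", "numpy"])
for name, version in sorted(report.items()):
    print(f"{name}: {version or 'not installed'}")
```

Running it in both places and comparing the output would make a mismatch such as TensorFlow 1.3 against a newer release visible immediately.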

Thanks all

jvishnuvardhan commented 5 years ago

As it was resolved, I will close this issue. Thanks!