A deep matching model library for recommendations & advertising. It's easy to train models and to export representation vectors which can be used for ANN search.
Describe the bug(问题描述)
Setting `sample_weight` when calling `fit` on a DSSM model raises an error. `sample_weight` follows the documented format: a 1-D numpy array with the same length as the labels.
To Reproduce(复现步骤)
Run the following code:

```python
history = model.fit(train_model_input, train_label,
                    batch_size=256, epochs=4, verbose=1, validation_split=0.0,
                    sample_weight=sample_weights)
```
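For context, `sample_weights` is assumed to be built as described above: a plain 1-D numpy array aligned with the labels. A minimal sketch (the real `train_label` comes from the training data; here it is faked only to show the shapes):

```python
import numpy as np

# Hypothetical stand-in for the real labels, just to illustrate shapes.
train_label = np.random.randint(0, 2, size=1024)

# One weight per training example, same length as the labels.
sample_weights = np.ones(len(train_label), dtype=np.float32)

assert sample_weights.shape == train_label.shape
```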
Operating environment(运行环境):
- python version [3.7]
- tensorflow version [1.15.0]
- deepmatch version [0.3.1]
Additional context
Resulting error:
```
Traceback (most recent call last):
  File "dssm_qt_train_noid.py", line 93, in <module>
    sample_weight = sample_weights)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 727, in fit
    use_multiprocessing=use_multiprocessing)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 675, in fit
    steps_name='steps_per_epoch')
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 394, in model_iteration
    batch_outs = f(ins_batch)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
    run_metadata=self.run_metadata)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
    run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[0], expected a dimension of 1, got 256
  [[{{node loss_1/in_batch_softmax_layer_loss/weighted_loss/Squeeze}}]]
```
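The failing op squeezes the leading dimension of the weighted loss, which it expects to be 1, but it receives the batch size (256). A minimal numpy analogue of that shape rule (not DeepMatch or TensorFlow internals, just an illustration):

```python
import numpy as np

# Squeezing a given axis only succeeds when that axis has size 1.
ok = np.squeeze(np.zeros((1,)), axis=0)  # succeeds: dim[0] == 1

try:
    np.squeeze(np.zeros((256,)), axis=0)  # fails: dim[0] == 256
    squeezed = True
except ValueError:
    squeezed = False

print(squeezed)  # False
```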