shenweichen / DeepMatch

A deep matching model library for recommendations & advertising. It's easy to train models and to export representation vectors which can be used for ANN search.
https://deepmatch.readthedocs.io/en/latest/
Apache License 2.0

Setting sample_weight on a DSSM model raises an error #101

Open zhuchenxi opened 9 months ago

zhuchenxi commented 9 months ago

Describe the bug: Setting sample_weight on a DSSM model raises an error. sample_weight follows the expected format: a one-dimensional NumPy array with the same length as the labels.

To Reproduce: run

history = model.fit(train_model_input, train_label, batch_size=256, epochs=4, verbose=1, validation_split=0.0, sample_weight = sample_weights)

Operating environment:

Additional context: the resulting error:

Traceback (most recent call last):
  File "dssm_qt_train_noid.py", line 93, in <module>
    sample_weight = sample_weights)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 727, in fit
    use_multiprocessing=use_multiprocessing)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 675, in fit
    steps_name='steps_per_epoch')
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 394, in model_iteration
    batch_outs = f(ins_batch)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
    run_metadata=self.run_metadata)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
    run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[0], expected a dimension of 1, got 256
	 [[{{node loss_1/in_batch_softmax_layer_loss/weighted_loss/Squeeze}}]]
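One possible reading of this error, offered as an assumption rather than a confirmed diagnosis: per-sample weighting only works when the loss yields one value per example. The node name `in_batch_softmax_layer_loss/weighted_loss/Squeeze` and the message "expected a dimension of 1, got 256" are consistent with the in-batch softmax loss already being reduced to a single value before the 256 weights are applied. The pure-Python sketch below illustrates only that shape requirement. `weighted_mean_loss` is a hypothetical helper, not a DeepMatch or Keras API. It loosely mirrors "weighted sum divided by batch size".

```python
from typing import Sequence


def weighted_mean_loss(per_example_loss: Sequence[float],
                       sample_weight: Sequence[float]) -> float:
    """Hypothetical helper: weight each example's loss, then average over the batch.

    Per-sample weighting needs exactly one loss value per weight.
    """
    if len(per_example_loss) != len(sample_weight):
        raise ValueError(
            f"got {len(sample_weight)} sample weights but {len(per_example_loss)} "
            "loss value(s); per-sample weighting needs one loss value per example"
        )
    weighted_total = sum(loss * w for loss, w in zip(per_example_loss, sample_weight))
    return weighted_total / len(per_example_loss)


batch_size = 256
weights = [1.0] * batch_size

# Case 1: the loss keeps one value per example, so weighting works.
per_example = [0.5] * batch_size
print(weighted_mean_loss(per_example, weights))  # 0.5

# Case 2 (assumed to resemble the report): the loss was already averaged to a
# single value, so the 256 weights have nothing to line up with.
already_reduced = [sum(per_example) / batch_size]
try:
    weighted_mean_loss(already_reduced, weights)
except ValueError as err:
    print("error:", err)
```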