Leavingseason / xDeepFM

Bug for `_build_embedding` in class ExtremeDeepFMModel #1

wenruij closed this issue 6 years ago

wenruij commented 6 years ago

@Leavingseason There seems to be a bug in the function `_build_embedding`:

def _build_embedding(self, hparams):
    fm_sparse_index = tf.SparseTensor(self.iterator.dnn_feat_indices,
                                      self.iterator.dnn_feat_values,
                                      self.iterator.dnn_feat_shape)
    fm_sparse_weight = tf.SparseTensor(self.iterator.dnn_feat_indices,
                                       self.iterator.dnn_feat_weights,
                                       self.iterator.dnn_feat_shape)
    w_fm_nn_input_orgin = tf.nn.embedding_lookup_sparse(self.embedding,
                                                        fm_sparse_index,
                                                        fm_sparse_weight,
                                                        combiner="sum")
    embedding = tf.reshape(w_fm_nn_input_orgin, [-1, hparams.dim * hparams.FIELD_COUNT])
    embedding_size = hparams.FIELD_COUNT * hparams.dim
    return embedding, embedding_size

You reshape the output of `tf.nn.embedding_lookup_sparse`. Here is a simplified example that imitates it:

import tensorflow as tf
import numpy as np
dim=3
FIELD_COUNT=4
batch_size=3
ids = tf.SparseTensor(indices=[[0,0], [0,3], [1,1],[2,1]], values=[1, 3, 6, 3], dense_shape=[batch_size,FIELD_COUNT])
sp_weights = tf.SparseTensor(indices=[[0,0], [0,3], [1,1],[2,1]], values=[1,1,1,1], dense_shape=[batch_size,FIELD_COUNT])
params = tf.constant([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,20,30],[40,50,60],[70,80,90]])
embed = tf.nn.embedding_lookup_sparse(params, ids, sp_weights, combiner="sum")
sess = tf.Session()
sess.run(embed)

# array([[14, 16, 18],
#       [70, 80, 90],
#       [10, 11, 12]], dtype=int32)

embedding = tf.reshape(embed, [-1, dim * FIELD_COUNT])
sess.run(embedding) 

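Running `embedding` in the session raises the following error, because the lookup returns one row of `dim` values per instance (3 x 3 = 9 values), while the reshape needs a multiple of `dim * FIELD_COUNT` = 12: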

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 9 values, but the requested shape requires a multiple of 12
     [[Node: Reshape_3 = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_lookup_sparse_2, Reshape_3/shape)]]

Leavingseason commented 6 years ago

Hi wenruij, that is not a bug. We have some constraints on the format of the input data. Basically, it's a field-wise format: each line should contain exactly `FIELD_COUNT` fields, and each field can contain one or more features. I think I need to provide a more detailed document that describes this data format.
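A minimal sketch of why the field-wise constraint would make the reshape valid. It assumes (this is not confirmed in the thread) that the sparse tensor holds one row per (instance, field) pair, so the lookup returns `batch_size * FIELD_COUNT` rows. The helper `lookup_sum` and the `batch` data are hypothetical; the helper only mimics `tf.nn.embedding_lookup_sparse` with `combiner="sum"` and unit weights in NumPy, so the sketch runs without TensorFlow.

```python
import numpy as np

dim = 3
FIELD_COUNT = 4
batch_size = 3

# Same embedding table as in the example above.
params = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
                   [10, 20, 30], [40, 50, 60], [70, 80, 90]])


def lookup_sum(indices, values, num_rows):
    """Hypothetical stand-in for embedding_lookup_sparse(combiner="sum")
    with unit weights: sum the embeddings of all ids in each sparse row."""
    out = np.zeros((num_rows, params.shape[1]), dtype=params.dtype)
    for (row, _), feat_id in zip(indices, values):
        out[row] += params[feat_id]
    return out


# Layout from the failing example: one sparse row per instance.
per_instance = lookup_sum([[0, 0], [0, 3], [1, 1], [2, 1]],
                          [1, 3, 6, 3], batch_size)
# per_instance has shape (3, 3): 9 values, not a multiple of 12.

# Assumed field-wise layout: every instance has exactly FIELD_COUNT fields,
# and each field lists one or more feature ids (made-up data).
batch = [
    [[1, 3], [0], [2], [4]],
    [[6], [5], [0], [1]],
    [[3], [2], [4, 5], [6]],
]
fw_indices, fw_values = [], []
for i, fields in enumerate(batch):
    for f, ids in enumerate(fields):
        for k, feat_id in enumerate(ids):
            # One sparse row per (instance, field) pair.
            fw_indices.append([i * FIELD_COUNT + f, k])
            fw_values.append(feat_id)

per_field = lookup_sum(fw_indices, fw_values, batch_size * FIELD_COUNT)
# per_field has shape (12, 3), so the reshape below succeeds.
embedding = per_field.reshape(-1, dim * FIELD_COUNT)
print(per_instance.shape, per_field.shape, embedding.shape)
```

Under this assumption each output row is the concatenation of the `FIELD_COUNT` per-field embeddings of one instance, which is what the reshape in `_build_embedding` appears to expect.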

wenruij commented 6 years ago

@Leavingseason Thanks for your reply. Looking forward to your deep learning based factorization toolkit and all mentioned models.