geffy / tffm

TensorFlow implementation of an arbitrary order Factorization Machine
MIT License

Upgrade TensorFlow to v1.0: tensor shape issue in pow_wrapper #20

Closed: Vimos closed this issue 7 years ago

Vimos commented 7 years ago

After upgrading TensorFlow to tensorflow-1.0.0-cp27-cp27mu-manylinux1_x86_64.whl, the original code fails with the following error:

Traceback (most recent call last):
  File "training.py", line 65, in <module>
    obj.training()
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/train/split.py", line 264, in training
    getattr(self.trainer, "{}_train".format(method))(**ins))
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/train/trainer.py", line 103, in tffm_train
    fm.fit(self.x_train, y_train, self.punishment, n_epochs=n_epochs, show_progress=True)
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/tffm/base.py", line 247, in fit
    self.core.build_graph()
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/tffm/core.py", line 214, in build_graph
    self.init_main_block()
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/tffm/core.py", line 180, in init_main_block
    self.pow_matmul(i, in_pows[pow_idx]),
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/tffm/core.py", line 122, in pow_matmul
    x_pow = pow_wrapper(self.train_x, pow, self.input_type)
  File "/data/home/vimos/Public/git/github/hotel-revenue/revenueml/revenueml/tffm/core.py", line 261, in pow_wrapper
    return tf.SparseTensor(X.indices, tf.pow(X.values, p), X.shape)
AttributeError: 'SparseTensor' object has no attribute 'shape'

I fixed it with the following change, since TensorFlow 1.0 renamed SparseTensor.shape to SparseTensor.dense_shape:

--- a/revenueml/tffm/core.py
+++ b/revenueml/tffm/core.py
@@ -258,6 +258,6 @@ def pow_wrapper(X, p, optype):
     if optype == 'dense':
         return tf.pow(X, p)
     elif optype == 'sparse':
-        return tf.SparseTensor(X.indices, tf.pow(X.values, p), X.shape)
+        return tf.SparseTensor(X.indices, tf.pow(X.values, p), X.dense_shape)
     else:
         raise NameError('Unknown input type in pow_wrapper')
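If the same code has to run against both pre-1.0 and 1.0+ TensorFlow, one option is a small duck-typed accessor instead of hard-coding either attribute name. The helper below, sparse_dense_shape, is hypothetical (it is not part of tffm or TensorFlow), and this is only a sketch exercised against stand-in objects rather than real SparseTensors.

```python
from types import SimpleNamespace


def sparse_dense_shape(sp):
    """Return the dense-shape component of a SparseTensor-like object.

    Hypothetical helper. TensorFlow 1.0 renamed SparseTensor.shape to
    SparseTensor.dense_shape, so prefer the new name and fall back to the
    old one for pre-1.0 objects. The order matters: later TensorFlow
    releases also expose a .shape property, but it returns the static
    TensorShape, not the dense-shape tensor the constructor expects.
    """
    if hasattr(sp, "dense_shape"):
        return sp.dense_shape
    return sp.shape


# Stand-in objects (not real SparseTensors) mimicking the two attribute layouts.
new_style = SimpleNamespace(dense_shape=[3, 4], shape="static TensorShape")
old_style = SimpleNamespace(shape=[3, 4])

print(sparse_dense_shape(new_style))  # [3, 4]
print(sparse_dense_shape(old_style))  # [3, 4]
```

With such a helper, the sparse branch of pow_wrapper would read: return tf.SparseTensor(X.indices, tf.pow(X.values, p), sparse_dense_shape(X)). For a codebase that only targets TensorFlow 1.0 and later, the one-line X.dense_shape fix above is simpler.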

I hope this fix is correct and helpful to others.

geffy commented 7 years ago

This fix is correct. I will add it to the code soon (along with other updates).

geffy commented 7 years ago

Fixed and merged.