millmi17 opened this issue 3 years ago
Thanks for reporting this bug. We had not handled this edge case gracefully. Here your training set is much smaller than the batch count, so each batch ends up holding a single triple, and that is why you get this error. I will mark this as a bug. In the meantime, try a smaller batch count, say 3 or 4.
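A minimal sketch of the arithmetic behind the reply above. It assumes the batch size is derived as ceil(n_triples / batches_count); that formula, and the helpers `batch_size_for` and `suggest_batches_count`, are hypothetical illustrations rather than AmpliGraph's code. The NumPy lines show how a one-row batch can lose its batch axis and become shape (3,) instead of (?, 3). Whether AmpliGraph drops the axis by indexing or by squeezing is also an assumption.

```python
import math

import numpy as np


def batch_size_for(n_triples: int, batches_count: int) -> int:
    # Hypothetical helper: ASSUMES batch size = ceil(n_triples / batches_count).
    return math.ceil(n_triples / batches_count)


def suggest_batches_count(n_triples: int, min_batch_size: int = 2) -> int:
    # Hypothetical helper: largest batches_count whose nominal batch size
    # ceil(n_triples / batches_count) is still at least min_batch_size.
    return max(1, n_triples // min_batch_size)


n_train = 8  # 9 triples in the CSV minus 1 test triple (test_size=1)
print(batch_size_for(n_train, 100))  # 1: every batch holds a single triple

X = np.array([['bill', 'takes', 'calc'],
              ['bill', 'is a', 'person']], dtype=object)
print(X[0:1].shape)              # (1, 3): slicing keeps the batch axis
print(X[0].shape)                # (3,): integer indexing drops it
print(np.squeeze(X[0:1]).shape)  # (3,): squeezing a one-row batch drops it too

print(suggest_batches_count(n_train))  # 4, in line with the "3 or 4" advice
```

With `batches_count=4` the assumed formula gives batches of 2 triples, and with `batches_count=3` batches of 3, so the batch axis never collapses to a single row.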
I am getting the error below when calling model.fit(X_train, early_stopping=False). To build X_train I loaded a CSV with the load_from_csv function, which returns the following:

array([['bill ', 'takes', 'calc'],
       ['bill ', 'is a ', 'person'],
       ['fred ', 'takes ', 'eng'],
       ['fred ', 'takes', 'chem'],
       ['chem ', 'located in ', 'pike'],
       ['pike ', 'is a ', 'building'],
       ['calc', 'located in ', 'smith'],
       ['calc ', 'is a ', 'building'],
       ['fred ', 'is a ', 'person']], dtype=object)

I am new to GitHub, so if you need anything else to help fix this, let me know.
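Separately from the batch-size problem, several labels in the array above carry trailing spaces. As a result, 'takes' and 'takes ' (or 'calc' and 'calc ') are treated as different relations or entities. Here is a minimal sketch that strips the whitespace with plain NumPy after loading. It assumes the loaded data is an object array of Python strings, as shown above.

```python
import numpy as np

# A few triples copied from the report; note the trailing spaces in some labels.
X = np.array([['bill ', 'takes', 'calc'],
              ['fred ', 'takes ', 'eng'],
              ['calc', 'located in ', 'smith'],
              ['calc ', 'is a ', 'building']], dtype=object)

# Before cleaning: 'takes' and 'takes ' count as two distinct relations.
print(sorted(set(X[:, 1])))

# Strip leading/trailing whitespace element-wise, keeping dtype=object.
X_clean = np.vectorize(str.strip, otypes=[object])(X)
print(sorted(set(X_clean[:, 1])))  # ['is a', 'located in', 'takes']
```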
This is just a small toy dataset I put together to see whether I could get this working with AmpliGraph. I then used the code below to get the train/test split, following the example code provided for ComplEx. The details are shown below.
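import tensorflow as tf
from ampligraph.evaluation import train_test_split_no_unseen
from ampligraph.latent_features import ComplEx

X_train, X_test = train_test_split_no_unseen(x, test_size=1)
X_train[1]
array(['bill ', 'takes', 'calc'], dtype=object)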
model = ComplEx(batches_count=100, seed=0, epochs=1, k=150, eta=5, optimizer='adam', optimizer_params={'lr':1e-3}, loss='multiclass_nll', regularizer='LP', regularizer_params={'p':3, 'lambda':1e-5}, verbose=True)
tf.logging.set_verbosity(tf.logging.ERROR)
model.fit(X_train, early_stopping = False)
InvalidArgumentError                      Traceback (most recent call last)
~/dss/code-envs/python/kge/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py in _do_call(self, fn, args)
   1364     try:
-> 1365       return fn(*args)
   1366     except errors.OpError as e:

~/dss/code-envs/python/kge/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1349       return self._call_tf_sessionrun(options, feed_dict, fetch_list,
-> 1350                                       target_list, run_metadata)
   1351

~/dss/code-envs/python/kge/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1442                                             fetch_list, target_list,
-> 1443                                             run_metadata)
   1444
InvalidArgumentError: ValueError: `generator` yielded an element of shape (3,) where an element of shape (?, 3) was expected.
Traceback (most recent call last):

  File "/home/dataiku/dss/code-envs/python/kge/lib64/python3.6/site-packages/tensorflow_core/python/ops/script_ops.py", line 235, in __call__
    ret = func(*args)

  File "/home/dataiku/dss/code-envs/python/kge/lib64/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 630, in generator_py_func
    "of shape %s was expected." % (ret_array.shape, expected_shape))

ValueError: `generator` yielded an element of shape (3,) where an element of shape (?, 3) was expected.

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)