fuxiAIlab / perCLTV

【TOIS2023】perCLTV: A General System for Personalized Customer Lifetime Value Prediction in Online Games
10 stars 5 forks source link

实际预测的时候图注意力层报错。 #2

Open leiyinghanguang opened 2 months ago

leiyinghanguang commented 2 months ago

for train_index, test_index in kfold.split(B, y1):
    print('train_index', train_index)
    print('test_index', test_index)
    # Carve a 10% validation split out of this fold's training indices.
    train_index, val_index = train_test_split(
        train_index, test_size=0.1, random_state=seed_value)

    # The model is always run on all N nodes at once (the graph A spans every
    # node). Train/val/test membership is therefore expressed as boolean masks
    # passed to Keras as per-sample weights, not by slicing the inputs.
    mask_train = np.zeros(N, dtype=bool)
    mask_val = np.zeros(N, dtype=bool)
    mask_test = np.zeros(N, dtype=bool)
    mask_train[train_index] = True
    mask_val[val_index] = True
    mask_test[test_index] = True

    checkpoint_path = './model/checkpoint-{epoch:04d}.ckpt'
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Start every fold with an empty checkpoint directory, so the checkpoint
    # restored after training belongs to this fold.
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=5,
                                                      mode='min')

    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                         monitor='val_loss',
                                                         verbose=1,
                                                         save_best_only=True,
                                                         save_weights_only=True,
                                                         mode='auto')

    model = perCLTV(timestep=timestep, behavior_maxlen=maxlen)

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss={'output_1': tf.keras.losses.BinaryCrossentropy(),
                        'output_2': tf.keras.losses.MeanSquaredError()},
                  loss_weights={'output_1': beta1, 'output_2': beta2},
                  metrics={'output_1': tf.keras.metrics.AUC(),
                           'output_2': 'mae'})

    # batch_size=N keeps the whole graph in a single batch. Validation
    # inherits the same batch size because validation_batch_size is not set.
    model.fit([B, C, P, A], [y1, y2],
              validation_data=([B, C, P, A], [y1, y2], mask_val),
              sample_weight=mask_train,
              batch_size=N,
              epochs=1,
              shuffle=False,
              callbacks=[early_stopping, best_checkpoint],
              verbose=1)

    # EarlyStopping is not configured with restore_best_weights, so after
    # fit() the model holds the last epoch's weights. Reload the best weights
    # saved by ModelCheckpoint (save_best_only=True), if one was written.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is not None:
        model.load_weights(latest_checkpoint)

    # predict() must use the same batch_size=N as fit(). With the default
    # batch_size of 32, Keras feeds B/C/P in 32-row batches, but the GAT
    # layer's adjacency still refers to node indices outside that batch.
    # That mismatch is the reported GatherV2 error
    # "indices[3] = 33 is not in [0, 32)".
    predictions = model.predict([B, C, P, A], batch_size=N)
    # Do not slice the inputs (e.g. B[mask_val]) to predict a subset, because
    # A indexes the full node set. Slice the outputs instead, for example
    # [p[mask_test] for p in predictions] for this fold's test nodes.
print('predictions:',predictions)                 Traceback (most recent call last):

  File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in <module>
    predictions = model.predict([B, C, P,A])
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "J:\MetajoyAlogrithm\.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "J:\MetajoyAlogrithm\.venv2\Lib\site-packages\tensorflow\python\eager\execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'per_cltv/social_behavior_net/gat_conv/GatherV2' defined at (most recent call last): File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in predictions = model.predict([B, C, P,A]) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(*args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2382, in predict tmp_batch_outputs = self.predict_function(iterator) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2169, in predict_function return step_function(self, iterator) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2155, in step_function outputs = model.distribute_strategy.run(run_step, args=(data,)) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2143, in run_step outputs = model.predict_step(data) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2111, in predict_step return self(x, training=False) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(*args, *kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call return super().call(args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(*args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call outputs = call_fn(inputs, *args, *kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler return fn(args, kwargs) File "J:\MetajoyAlogrithm\perCLTV-master\src\model.py", line 75, in call O = self.social_behavior_net([X, A]) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(*args, kwargs) File 
"J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call return super().call(*args, *kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call outputs = call_fn(inputs, *args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler return fn(args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\sequential.py", line 427, in call outputs = layer(inputs, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler return fn(args, kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call outputs = call_fn(inputs, *args, *kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler return fn(args, **kwargs) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\conv.py", line 167, in _inner_check_dtypes File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 168, in call if mode == modes.SINGLE and K.is_sparse(a): File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 169, in call output, attn_coef = self._call_single(x, a) File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 213, in _call_single attn_for_self = tf.gather(attn_for_self, targets) Node: 'per_cltv/social_behavior_net/gat_conv/GatherV2' indices[3] = 33 is not in [0, 32) [[{{node per_cltv/social_behavior_net/gat_conv/GatherV2}}]] [Op:__inference_predict_function_29622] 报错,这一行 predictions = model.predict([B, C, P,A])

leiyinghanguang commented 2 months ago

图注意力层报错