alibaba / euler

A distributed graph deep learning framework.
Apache License 2.0

scalable_sage model: save_embedding exports identical vectors for all nodes #208

Open wenruij opened 4 years ago

wenruij commented 4 years ago

Data description

label: a multi-hot encoding of length 25 (most nodes have only one class, so effectively one-hot)
feature: a dense vector of length 100

Graph Load Finish! Node Count:523375 Edge Count:37921281
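For reference, a minimal sketch of the per-node data shapes described above (the variable names and the chosen class index are illustrative, not taken from the actual dataset):

    import numpy as np

    # length-25 multi-hot label; most nodes have a single class, i.e. one-hot
    label = np.zeros(25, dtype=np.float32)
    label[7] = 1.0                      # hypothetical class index
    # length-100 dense feature vector (random values stand in for real features)
    feature = np.random.randn(100).astype(np.float32)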

Training code

  utils_flags.set_defaults(
      mode='train',
      train_node_type=1,
      data_dir='/data1/jwr/project/gcn/data/20191217',
      max_id=523374,
      label_idx=0,
      label_dim=25,
      feature_idx=3,
      feature_dim=100,
      dim=100,
      fanouts=[10, 10],
      aggregator='meanpool',
      use_residual=False,
      batch_size=256,
      optimizer='adam',
      learning_rate=0.008,
      store_learning_rate=0.002,
      num_epochs=2,
      model_dir='ckpt/' + str(int(time.time())),
      model='scalable_sage')
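For reference, a minimal sketch of how the export run is presumably set up (an assumption about the invocation, not the project's documented entry point): the flag defaults above are reused, overriding only mode and model_dir, since the timestamped model_dir changes on every run and the export needs to read the trained checkpoint.

    # assumption: reuse the training defaults, overriding only these two flags
    utils_flags.set_defaults(
        mode='save_embedding',
        model_dir='ckpt/1577072290')   # checkpoint directory from the training log below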

Training log

INFO:tensorflow:f1 = 0.57615125, loss = 0.072720595, step = 1940 (0.170 sec)
INFO:tensorflow:f1 = 0.576606, loss = 0.06780708, step = 1960 (0.183 sec)
INFO:tensorflow:f1 = 0.57674795, loss = 0.07061373, step = 1980 (0.175 sec)
INFO:tensorflow:f1 = 0.57738733, loss = 0.067535296, step = 2000 (0.178 sec)
INFO:tensorflow:f1 = 0.5778596, loss = 0.06816011, step = 2020 (0.174 sec)
INFO:tensorflow:f1 = 0.5778974, loss = 0.07258644, step = 2040 (0.179 sec)
INFO:tensorflow:Saving checkpoints for 2044 into ckpt/1577072290/model.ckpt.
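Before blaming the export step, one sanity check is to inspect the saved checkpoint with standard TensorFlow APIs and confirm the trained variables actually carry non-trivial values (a debugging sketch, independent of Euler; the checkpoint path is taken from the log line above):

    import numpy as np
    import tensorflow as tf

    ckpt_dir = 'ckpt/1577072290'
    reader = tf.train.load_checkpoint(ckpt_dir)
    for name, shape in tf.train.list_variables(ckpt_dir):
        value = reader.get_tensor(name)
        # a trained weight matrix should show some spread across its entries
        print(name, shape, 'std =', float(np.std(value)))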

Problem when exporting embeddings

After training finished, I set mode='save_embedding' to export the node embeddings, but every node's embedding vector is identical. I have tried changing batch_size, optimizer, learning_rate, aggregator and other parameters, and none of it helped: the exported node embeddings are still all the same.
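To confirm the symptom quantitatively, a minimal sketch that counts distinct rows in the exported embeddings, assuming they were saved as a NumPy-loadable text matrix of shape [num_nodes, dim] (the actual save_embedding output format may differ, and the file path is hypothetical):

    import numpy as np

    emb = np.loadtxt('embedding.txt')          # hypothetical path and format
    distinct = np.unique(np.round(emb, 6), axis=0)
    print('nodes:', emb.shape[0], 'distinct embeddings:', distinct.shape[0])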