Open wenruij opened 4 years ago
Data description

label: multi-hot encoding of length 25 (most nodes belong to a single class, so the label is effectively one-hot)
feature: dense vector of length 100

Graph Load Finish! Node Count:523375 Edge Count:37921281

Training code

    utils_flags.set_defaults(
        mode='train',
        train_node_type=1,
        data_dir='/data1/jwr/project/gcn/data/20191217',
        max_id=523374,
        label_idx=0,
        label_dim=25,
        feature_idx=3,
        feature_dim=100,
        dim=100,
        fanouts=[10, 10],
        aggregator='meanpool',
        use_residual=False,
        batch_size=256,
        optimizer='adam',
        learning_rate=0.008,
        store_learning_rate=0.002,
        num_epochs=2,
        model_dir='ckpt/' + str(int(time.time())),
        model='scalable_sage')

Training log

    INFO:tensorflow:f1 = 0.57615125, loss = 0.072720595, step = 1940 (0.170 sec)
    INFO:tensorflow:f1 = 0.576606, loss = 0.06780708, step = 1960 (0.183 sec)
    INFO:tensorflow:f1 = 0.57674795, loss = 0.07061373, step = 1980 (0.175 sec)
    INFO:tensorflow:f1 = 0.57738733, loss = 0.067535296, step = 2000 (0.178 sec)
    INFO:tensorflow:f1 = 0.5778596, loss = 0.06816011, step = 2020 (0.174 sec)
    INFO:tensorflow:f1 = 0.5778974, loss = 0.07258644, step = 2040 (0.179 sec)
    INFO:tensorflow:Saving checkpoints for 2044 into ckpt/1577072290/model.ckpt.

Problem with the exported embeddings

After training, I set mode='save_embedding' to export the node vectors, but the embedding vectors of all nodes come out identical. I have tried changing batch_size, optimizer, learning_rate, aggregator and other parameters. None of it helped: the exported node embeddings are still all the same.
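
To make the "all embeddings are identical" symptom concrete, a small check like the one below could be run on the exported matrix. This is only a sketch: `summarize_embeddings` is a hypothetical helper, not part of Euler, and it assumes the exported vectors can be loaded as a 2-D NumPy array of shape (num_nodes, dim).

```python
import numpy as np


def summarize_embeddings(emb, atol=1e-6):
    """Summarize how much the rows of a (num_nodes, dim) matrix differ.

    Hypothetical diagnostic helper, not part of any library.
    """
    emb = np.asarray(emb)
    if emb.ndim != 2 or emb.shape[0] == 0:
        raise ValueError("expected a non-empty 2-D array of shape (num_nodes, dim)")

    # Largest absolute difference between any row and the first row.
    max_dev = float(np.abs(emb - emb[0]).max())

    # Standard deviation of each embedding dimension across nodes,
    # averaged over dimensions; near zero means the rows barely vary.
    mean_std = float(emb.std(axis=0).mean())

    # Number of distinct rows after rounding away float noise.
    # This sorts the whole matrix, so subsample first if memory is tight.
    num_distinct = int(np.unique(np.round(emb, 6), axis=0).shape[0])

    return {
        "all_identical": bool(max_dev <= atol),
        "max_abs_deviation": max_dev,
        "mean_per_dim_std": mean_std,
        "num_distinct_rows": num_distinct,
    }


if __name__ == "__main__":
    # Synthetic stand-in for an exported matrix: five copies of one vector.
    collapsed = np.tile(np.arange(4.0), (5, 1))
    print(summarize_embeddings(collapsed))
```

For a real run, the synthetic matrix would be replaced by the exported file, for example `np.load(path)` where `path` points at whatever the save_embedding step wrote into model_dir (the file name and format are an assumption here). If `all_identical` is true for the exported matrix but the training f1 keeps improving, comparing the inputs fed in save_embedding mode against those used in train mode would be a natural next step.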