RUCAIBox / RecBole-GNN

Efficient and extensible GNNs enhanced recommender library based on RecBole.
MIT License
167 stars 37 forks source link

FEA: add sparse tensor support for ngcf_conv and lightgcn_conv #75

Closed downeykking closed 10 months ago

downeykking commented 10 months ago

为ngcf conv层和lightgcn conv层添加了稀疏矩阵支持,性能基本上维持不变,在测试结果上显存占用变为1/5,提速5x。

对general recommender下面的所有方法都进行了适配(如果用到了ngcf conv或lightgcn conv),所以基于这几个backbone的模型都可以实现提速。为了更好兼容现有代码逻辑,需要手动设置参数才会开启加速。 assert config["enable_sparse"] in [True, False, None]

具体使用方法:

需要安装torch_sparse

设置新参数enable_sparse==True来开启稀疏矩阵支持

如果开启了enable_sparse==True但torch_sparse不可用,会使用原来的dense edge_index

序列推荐其实也有卷积层,也可以进行加速,但是我之前做序列推荐做的比较少,我有时间再研究一下,当作一个todo计划 :)

主要的测试结果: 对于NGCF,15个epoch下,ml-1m数据集 Sparse:占用显存0.18G,训练总时间33.85s Origin:占用显存1.01G,训练总时间155.79s

Sparse:
best valid : OrderedDict([('recall@10', 0.1096), ('recall@20', 0.1819), ('recall@50', 0.3257), ('mrr@10', 0.2796), ('mrr@20', 0.2888), ('mrr@50', 0.2933), ('ndcg@10', 0.1474), ('ndcg@20', 0.1628), ('ndcg@50', 0.2071), ('hit@10', 0.5884), ('hit@20', 0.7201), ('hit@50', 0.8541), ('precision@10', 0.1162), ('precision@20', 0.0988), ('precision@50', 0.0732)])
test result: OrderedDict([('recall@10', 0.124), ('recall@20', 0.1964), ('recall@50', 0.3377), ('mrr@10', 0.3341), ('mrr@20', 0.3422), ('mrr@50', 0.346), ('ndcg@10', 0.1817), ('ndcg@20', 0.1901), ('ndcg@50', 0.2296), ('hit@10', 0.6237), ('hit@20', 0.7398), ('hit@50', 0.8548), ('precision@10', 0.1429), ('precision@20', 0.1152), ('precision@50', 0.0815)])

Origin:
best valid : OrderedDict([('recall@10', 0.1085), ('recall@20', 0.1782), ('recall@50', 0.3184), ('mrr@10', 0.2797), ('mrr@20', 0.2882), ('mrr@50', 0.2927), ('ndcg@10', 0.1456), ('ndcg@20', 0.1603), ('ndcg@50', 0.2039), ('hit@10', 0.5871), ('hit@20', 0.7095), ('hit@50', 0.8456), ('precision@10', 0.1149), ('precision@20', 0.098), ('precision@50', 0.0731)])
test result: OrderedDict([('recall@10', 0.119), ('recall@20', 0.1906), ('recall@50', 0.3291), ('mrr@10', 0.3266), ('mrr@20', 0.3349), ('mrr@50', 0.3386), ('ndcg@10', 0.176), ('ndcg@20', 0.1848), ('ndcg@50', 0.2235), ('hit@10', 0.6159), ('hit@20', 0.7339), ('hit@50', 0.8466), ('precision@10', 0.1391), ('precision@20', 0.1136), ('precision@50', 0.0803)])

对于LightGCN,15个epoch下,ml-1m数据集 Sparse:占用显存0.17G,训练总时间23.99s Origin:占用显存1.02G,训练总时间102.95s

Sparse:
best valid : OrderedDict([('recall@10', 0.0891), ('recall@20', 0.1434), ('recall@50', 0.25), ('mrr@10', 0.25), ('mrr@20', 0.2587), ('mrr@50', 0.2632), ('ndcg@10', 0.1242), ('ndcg@20', 0.1326), ('ndcg@50', 0.1641), ('hit@10', 0.5214), ('hit@20', 0.6464), ('hit@50', 0.7824), ('precision@10', 0.0973), ('precision@20', 0.08), ('precision@50', 0.0583)])
test result: OrderedDict([('recall@10', 0.0956), ('recall@20', 0.1526), ('recall@50', 0.2579), ('mrr@10', 0.2876), ('mrr@20', 0.296), ('mrr@50', 0.3002), ('ndcg@10', 0.1452), ('ndcg@20', 0.1502), ('ndcg@50', 0.1783), ('hit@10', 0.5424), ('hit@20', 0.663), ('hit@50', 0.7912), ('precision@10', 0.1124), ('precision@20', 0.0905), ('precision@50', 0.0633)])

Origin:
best valid : OrderedDict([('recall@10', 0.0883), ('recall@20', 0.1433), ('recall@50', 0.2494), ('mrr@10', 0.2507), ('mrr@20', 0.2595), ('mrr@50', 0.2639), ('ndcg@10', 0.1242), ('ndcg@20', 0.1325), ('ndcg@50', 0.164), ('hit@10', 0.519), ('hit@20', 0.6454), ('hit@50', 0.7814), ('precision@10', 0.0974), ('precision@20', 0.08), ('precision@50', 0.0585)])
test result: OrderedDict([('recall@10', 0.0955), ('recall@20', 0.1524), ('recall@50', 0.258), ('mrr@10', 0.2864), ('mrr@20', 0.2949), ('mrr@50', 0.2992), ('ndcg@10', 0.1451), ('ndcg@20', 0.1499), ('ndcg@50', 0.1782), ('hit@10', 0.5406), ('hit@20', 0.6605), ('hit@50', 0.7908), ('precision@10', 0.1129), ('precision@20', 0.0905), ('precision@50', 0.0635)])

对于SGL,15个epoch下,ml-1m数据集 Sparse:占用显存0.60G,训练总时间70.84s Origin:占用显存1.09G,训练总时间382.86s

Sparse:
best valid : OrderedDict([('recall@10', 0.1299), ('recall@20', 0.1991), ('recall@50', 0.3302), ('mrr@10', 0.3238), ('mrr@20', 0.3315), ('mrr@50', 0.3348), ('ndcg@10', 0.1721), ('ndcg@20', 0.184), ('ndcg@50', 0.2245), ('hit@10', 0.638), ('hit@20', 0.7461), ('hit@50', 0.846), ('precision@10', 0.1311), ('precision@20', 0.1058), ('precision@50', 0.0753)])
test result: OrderedDict([('recall@10', 0.1414), ('recall@20', 0.215), ('recall@50', 0.3435), ('mrr@10', 0.3893), ('mrr@20', 0.3962), ('mrr@50', 0.3992), ('ndcg@10', 0.2085), ('ndcg@20', 0.2144), ('ndcg@50', 0.2491), ('hit@10', 0.6661), ('hit@20', 0.7645), ('hit@50', 0.8531), ('precision@10', 0.1559), ('precision@20', 0.1221), ('precision@50', 0.0824)])

Origin:
best valid : OrderedDict([('recall@10', 0.1263), ('recall@20', 0.1969), ('recall@50', 0.3253), ('mrr@10', 0.3225), ('mrr@20', 0.3304), ('mrr@50', 0.3337), ('ndcg@10', 0.1699), ('ndcg@20', 0.1824), ('ndcg@50', 0.2219), ('hit@10', 0.6308), ('hit@20', 0.7433), ('hit@50', 0.8432), ('precision@10', 0.1299), ('precision@20', 0.1054), ('precision@50', 0.0748)])
test result: OrderedDict([('recall@10', 0.1403), ('recall@20', 0.2102), ('recall@50', 0.3398), ('mrr@10', 0.3893), ('mrr@20', 0.3961), ('mrr@50', 0.3992), ('ndcg@10', 0.2076), ('ndcg@20', 0.2124), ('ndcg@50', 0.2473), ('hit@10', 0.6608), ('hit@20', 0.7567), ('hit@50', 0.8476), ('precision@10', 0.1554), ('precision@20', 0.1214), ('precision@50', 0.0823)])
hyp1231 commented 10 months ago

orz 太猛了,新实现很优雅,学到很多!想请教一下提高效率的原因是因为 torch_sparse 底层封装了更高效的稀疏矩阵算子吗?因为感觉理论上感觉数据组织形式和原实现相似,但是原实现的计算部分可能都是 naive 的 python 运算?

downeykking commented 10 months ago

orz 太猛了,新实现很优雅,学到很多!想请教一下提高效率的原因是因为 torch_sparse 底层封装了更高效的稀疏矩阵算子吗?因为感觉理论上感觉数据组织形式和原实现相似,但是原实现的计算部分可能都是 naive 的 python 运算?

在我个人的理解中,原始实现是进行两步骤,分别是message,aggregate,其中aggregate是通过torch_scatter实现的,从source节点聚合信息到target节点,比如这样:

        row, col = edge_index
        # x_j为聚合后的x_j,按照0-row.max().item()+1顺序排列
        x_j = scatter.scatter(x[col], row, dim=0, dim_size=x.size(0), reduce='mean')

所以这时候要显式提供edge_index.size(1)x,(这个数通常还是非常大,尽管作者对torch_scatter也用c++优化了速度,但显存和速度还是相对不足),并在这个基础上做聚合。 而torch_sparse可以利用message_and_aggregate变成一步运算,并且聚合的时候也用的是稀疏矩阵和稠密矩阵的乘法, x_j = matmul(adj_t, x, reduce=self.aggr) 作者写的torch_sparse中用c++重写了很多算子,实现了基于GPU的稀疏矩阵乘法的快速前向和后向传递。

参考:https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html

hyp1231 commented 10 months ago

原来是这样,学到了,感谢解答!!