Closed jasperzhong closed 1 year ago
SAR (MLSys '22)这篇文章其实大概意思也和P3 #285 类似,不过解决的问题不一样. P3是为了降低通信量,SAR好像是为了跑full-batch training,降低memory consumption的,有点类似DNN优化里面常用的Rematerialization.
这里面最重要的一个发现是,GNN的AGG操作是可以incrementally做的. AGG({A, B, C}) = AGG({A, AGG({B, C})})
对于GCN, GraphSAGE都是message求和,很easy. 甚至GAT都是可以的. 先算exp(e_j) * z_j,一个一个求和,最后整体normalize一个sum(exp(e_j))就行了.
所以压根不需要采样 ,对于一个node的forward pass,直接iterate all partitions,寻找其邻居,然后sequentially aggregate就行了.
对于backward pass,分两种情况,如果算grad需要z,就需要fetch,比如GAT的情况;如果不需要z,就不fetch,不引入额外的通信. 这一点的确和Rematerialization挺像.
而且发现没有,这里传的不是raw features,传的都是处理过的z.
文章还提到了其他一些tricks,比如GAT的kernel fusion,stable softmax等等.
实验部分提到他们还是用METIS partition graphs.
在obgn-product上训练GraphSAGE,peak memory好像也没降低多少啊,训练速度在只有4个partition还慢一些.
GAT慢挺多的,因为backward需要额外的通信.
https://arxiv.org/pdf/2111.06483.pdf