jasperzhong / read-papers-and-code

My paper/code reading notes in Chinese
44 stars 3 forks source link

MLSys '22 | Sequential Aggregation and Rematerialization: Distributed Full-batch Training of Graph Neural Networks on Large Graphs #377

Closed jasperzhong closed 1 year ago

jasperzhong commented 1 year ago

https://arxiv.org/pdf/2111.06483.pdf

jasperzhong commented 1 year ago

SAR (MLSys '22)这篇文章其实大概意思也和P3 #285 类似,不过解决的问题不一样. P3是为了降低通信量,SAR好像是为了跑full-batch training,降低memory consumption的,有点类似DNN优化里面常用的Rematerialization.

image

这里面最重要的一个发现是,GNN的AGG操作是可以incrementally做的. AGG({A, B, C}) = AGG({A, AGG({B, C})})

对于GCN, GraphSAGE都是message求和,很easy. 甚至GAT都是可以的. 先算exp(e_j) * z_j,一个一个求和,最后整体normalize一个sum(exp(e_j))就行了.

所以压根不需要采样 ,对于一个node的forward pass,直接iterate all partitions,寻找其邻居,然后sequentially aggregate就行了.

image

对于backward pass,分两种情况,如果算grad需要z,就需要fetch,比如GAT的情况;如果不需要z,就不fetch,不引入额外的通信. 这一点的确和Rematerialization挺像.

而且发现没有,这里传的不是raw features,传的都是处理过的z.

文章还提到了其他一些tricks,比如GAT的kernel fusion,stable softmax等等.

实验部分提到他们还是用METIS partition graphs.

image

在obgn-product上训练GraphSAGE,peak memory好像也没降低多少啊,训练速度在只有4个partition还慢一些.

image

GAT慢挺多的,因为backward需要额外的通信.