bfGraph / STGraph

🌟 Vertex-centric approach for building GNNs/TGNNs
MIT License

Seastar - RGCN #11

Open JoelMathewC opened 1 year ago

JoelMathewC commented 1 year ago

Seastar's original implementation does not provide a vertex-centric program for RGCN; instead, it uses a handwritten kernel in dgl-hack. Let's try to write a vertex-centric program for RGCN. This issue will track any problems we run into along the way.

Details about the original implementation of RGCN can be found here.
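As a starting point, here is a minimal sketch of what a vertex-centric RGCN layer has to compute, written as a plain PyTorch loop over vertices rather than in Seastar's programming model. It assumes the standard RGCN update without basis decomposition or activation: for each vertex i, h_i' = W_0 h_i + Σ_r Σ_{j ∈ N_i^r} (1 / c_{i,r}) W_r h_j, with c_{i,r} = |N_i^r|. The function name `rgcn_vertex_centric` and the edge-list layout (`src`, `dst`, `etype` tensors) are illustrative assumptions, not Seastar or dgl-hack APIs.

```python
import torch


def rgcn_vertex_centric(h, src, dst, etype, weight, self_weight):
    """Naive vertex-centric RGCN layer (hypothetical helper; no activation,
    no basis decomposition).

    h:           [N, in_dim] node features
    src, dst:    [E] edge endpoints (edge j -> i means src=j, dst=i)
    etype:       [E] relation id of each edge
    weight:      [R, in_dim, out_dim] one matrix per relation (W_r)
    self_weight: [in_dim, out_dim] self-loop matrix (W_0)
    """
    num_nodes = h.shape[0]
    out = h @ self_weight  # self-loop term W_0 h_i for every vertex
    for i in range(num_nodes):
        # The "vertex program": look only at the incoming edges of vertex i.
        in_edges = (dst == i).nonzero(as_tuple=True)[0]
        for r in etype[in_edges].unique():
            edges_r = in_edges[etype[in_edges] == r]
            # Transform each neighbour under relation r, then normalise by c_{i,r}.
            msgs = h[src[edges_r]] @ weight[r]
            out[i] = out[i] + msgs.sum(dim=0) / len(edges_r)
    return out
```

The per-vertex loop is far too slow for real use; its only purpose is to pin down the semantics that a Seastar vertex program would need to reproduce.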

JoelMathewC commented 1 year ago

Tried running train.py with dgl.RelGraphConv to see whether the training script can execute without dgl-hack. It turns out the script relies on a custom function, add_edges_with_type, that had been implemented separately. I also had to modify some input tensor sizes, but I got RGCN to run successfully on DGL.

The training log for DGL-RGCN on the aifb dataset is:

Train Accuracy: 0.2857 | Train Loss: 48903460.0000 | Validation Accuracy: 0.3214 | Validation loss: 56560608.0000
Train Accuracy: 0.2857 | Train Loss: 48815928.0000 | Validation Accuracy: 0.3214 | Validation loss: 56243768.0000
Train Accuracy: 0.2857 | Train Loss: 47918792.0000 | Validation Accuracy: 0.3214 | Validation loss: 54764828.0000
Epoch 00003 | Train Forward Time(s) 0.0026 | Backward Time(s) 0.0017
Train Accuracy: 0.2946 | Train Loss: 47904824.0000 | Validation Accuracy: 0.3214 | Validation loss: 54428792.0000
Epoch 00004 | Train Forward Time(s) 0.0026 | Backward Time(s) 0.0018
Train Accuracy: 0.2946 | Train Loss: 48096584.0000 | Validation Accuracy: 0.3214 | Validation loss: 54330104.0000
Epoch 00005 | Train Forward Time(s) 0.0029 | Backward Time(s) 0.0022
Train Accuracy: 0.3036 | Train Loss: 48391628.0000 | Validation Accuracy: 0.3214 | Validation loss: 54448644.0000
Epoch 00006 | Train Forward Time(s) 0.0029 | Backward Time(s) 0.0020
Train Accuracy: 0.3036 | Train Loss: 48688448.0000 | Validation Accuracy: 0.3214 | Validation loss: 54529480.0000
Epoch 00007 | Train Forward Time(s) 0.0028 | Backward Time(s) 0.0021
Train Accuracy: 0.3125 | Train Loss: 48979032.0000 | Validation Accuracy: 0.3214 | Validation loss: 54579388.0000
Epoch 00008 | Train Forward Time(s) 0.0027 | Backward Time(s) 0.0018
Train Accuracy: 0.3125 | Train Loss: 49252564.0000 | Validation Accuracy: 0.3214 | Validation loss: 54631652.0000
Epoch 00009 | Train Forward Time(s) 0.0028 | Backward Time(s) 0.0018
Train Accuracy: 0.3125 | Train Loss: 49504344.0000 | Validation Accuracy: 0.3214 | Validation loss: 54682868.0000
Epoch 00010 | Train Forward Time(s) 0.0030 | Backward Time(s) 0.0020
Train Accuracy: 0.3125 | Train Loss: 49736864.0000 | Validation Accuracy: 0.3214 | Validation loss: 54717436.0000

However, the model does not seem to be training: both losses stay in the tens of millions and validation accuracy is stuck at 0.3214.

JoelMathewC commented 1 year ago

Fixed the DGL code for the aifb dataset. This dataset has no node features, so instead we give each node an integer label and use it to look up a randomly initialized feature vector from a torch.nn.Embedding table. The task is to predict the type of each node.
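The embedding-as-input trick for a featureless graph can be sketched as below. The class name `FeaturelessInput` and the sizes are hypothetical placeholders (not the aifb values or the code in the branch); the point is only that each node ID indexes a trainable, randomly initialized row of `torch.nn.Embedding`, which then serves as the node's input feature.

```python
import torch
import torch.nn as nn


class FeaturelessInput(nn.Module):
    """Hypothetical module: learnable per-node input features for a graph
    that has no node features of its own."""

    def __init__(self, num_nodes: int, embed_dim: int):
        super().__init__()
        # One randomly initialised, trainable vector per node ID.
        self.embed = nn.Embedding(num_nodes, embed_dim)

    def forward(self, node_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(node_ids)


num_nodes, embed_dim = 10, 16  # hypothetical sizes, not the aifb values
inp = FeaturelessInput(num_nodes, embed_dim)
node_ids = torch.arange(num_nodes)  # "label" each node with its integer ID
h0 = inp(node_ids)  # [num_nodes, embed_dim], input to the first RGCN layer
```

Because the embedding table is a parameter, the optimizer updates these input features together with the RGCN weights.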

The updated training log is:

Epoch 00000 | Train Accuracy: 0.3393 | Train Loss: 1.3365 | Validation Accuracy: 0.3214 | Validation loss: 1.3165
Epoch 00001 | Train Accuracy: 0.3929 | Train Loss: 1.2702 | Validation Accuracy: 0.3929 | Validation loss: 1.2714
Epoch 00002 | Train Accuracy: 0.4375 | Train Loss: 1.2172 | Validation Accuracy: 0.4286 | Validation loss: 1.2256
Epoch 00003 | Train Accuracy: 0.6429 | Train Loss: 1.1597 | Validation Accuracy: 0.6786 | Validation loss: 1.1632
Epoch 00004 | Train Accuracy: 0.6964 | Train Loss: 1.1040 | Validation Accuracy: 0.7500 | Validation loss: 1.0950
Epoch 00005 | Train Accuracy: 0.7321 | Train Loss: 1.0607 | Validation Accuracy: 0.7500 | Validation loss: 1.0387
Epoch 00006 | Train Accuracy: 0.7768 | Train Loss: 1.0308 | Validation Accuracy: 0.7857 | Validation loss: 1.0008
Epoch 00007 | Train Accuracy: 0.8036 | Train Loss: 1.0075 | Validation Accuracy: 0.7500 | Validation loss: 0.9773
Epoch 00008 | Train Accuracy: 0.8125 | Train Loss: 0.9861 | Validation Accuracy: 0.8214 | Validation loss: 0.9633
Epoch 00009 | Train Accuracy: 0.8214 | Train Loss: 0.9651 | Validation Accuracy: 0.7857 | Validation loss: 0.9560
Epoch 00010 | Train Accuracy: 0.8571 | Train Loss: 0.9454 | Validation Accuracy: 0.7857 | Validation loss: 0.9532
Epoch 00011 | Train Accuracy: 0.8571 | Train Loss: 0.9280 | Validation Accuracy: 0.7500 | Validation loss: 0.9518
Epoch 00012 | Train Accuracy: 0.8661 | Train Loss: 0.9129 | Validation Accuracy: 0.7500 | Validation loss: 0.9480
Epoch 00013 | Train Accuracy: 0.8750 | Train Loss: 0.8998 | Validation Accuracy: 0.8214 | Validation loss: 0.9403
Epoch 00014 | Train Accuracy: 0.8750 | Train Loss: 0.8888 | Validation Accuracy: 0.8571 | Validation loss: 0.9298
Epoch 00015 | Train Accuracy: 0.8839 | Train Loss: 0.8798 | Validation Accuracy: 0.8929 | Validation loss: 0.9183
Epoch 00016 | Train Accuracy: 0.8839 | Train Loss: 0.8723 | Validation Accuracy: 0.8929 | Validation loss: 0.9078
Epoch 00017 | Train Accuracy: 0.8839 | Train Loss: 0.8660 | Validation Accuracy: 0.8929 | Validation loss: 0.8992
Epoch 00018 | Train Accuracy: 0.8839 | Train Loss: 0.8601 | Validation Accuracy: 0.8929 | Validation loss: 0.8923
Epoch 00019 | Train Accuracy: 0.8839 | Train Loss: 0.8544 | Validation Accuracy: 0.8929 | Validation loss: 0.8874
Epoch 00020 | Train Accuracy: 0.9018 | Train Loss: 0.8487 | Validation Accuracy: 0.8929 | Validation loss: 0.8839
Epoch 00021 | Train Accuracy: 0.9107 | Train Loss: 0.8434 | Validation Accuracy: 0.8929 | Validation loss: 0.8817
Epoch 00022 | Train Accuracy: 0.9196 | Train Loss: 0.8384 | Validation Accuracy: 0.8929 | Validation loss: 0.8804
Epoch 00023 | Train Accuracy: 0.9196 | Train Loss: 0.8336 | Validation Accuracy: 0.8929 | Validation loss: 0.8794
Epoch 00024 | Train Accuracy: 0.9196 | Train Loss: 0.8298 | Validation Accuracy: 0.8929 | Validation loss: 0.8785

JoelMathewC commented 1 year ago

The modified code has been moved to the exp/rgcn/dgl folder in the new seastar/rgcn branch. With a learning rate of 0.001, the training log is as follows:

Train Accuracy: 0.2143 | Train Loss: 1.4091 | Validation Accuracy: 0.2500 | Validation loss: 1.3989
Train Accuracy: 0.7857 | Train Loss: 0.9575 | Validation Accuracy: 0.8571 | Validation loss: 0.8758
Train Accuracy: 0.8214 | Train Loss: 0.9158 | Validation Accuracy: 0.9286 | Validation loss: 0.8145
Epoch 00003 | Train Forward Time(s) 0.0689 | Backward Time(s) 0.1214
Train Accuracy: 0.8661 | Train Loss: 0.8774 | Validation Accuracy: 0.9286 | Validation loss: 0.8150
Epoch 00004 | Train Forward Time(s) 0.0696 | Backward Time(s) 0.1222
Train Accuracy: 0.8661 | Train Loss: 0.8770 | Validation Accuracy: 0.9286 | Validation loss: 0.8157
Epoch 00005 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1221
Train Accuracy: 0.8661 | Train Loss: 0.8773 | Validation Accuracy: 0.9286 | Validation loss: 0.8161
Epoch 00006 | Train Forward Time(s) 0.0701 | Backward Time(s) 0.1212
Train Accuracy: 0.8661 | Train Loss: 0.8774 | Validation Accuracy: 0.9286 | Validation loss: 0.8159
Epoch 00007 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1210
Train Accuracy: 0.8661 | Train Loss: 0.8770 | Validation Accuracy: 0.9286 | Validation loss: 0.8155
Epoch 00008 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1227
Train Accuracy: 0.8750 | Train Loss: 0.8695 | Validation Accuracy: 0.9286 | Validation loss: 0.8143
Epoch 00009 | Train Forward Time(s) 0.0706 | Backward Time(s) 0.1212
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.8117
Epoch 00010 | Train Forward Time(s) 0.0699 | Backward Time(s) 0.1214
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.8063
Epoch 00011 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1210
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.7981
Epoch 00012 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1227
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7901
Epoch 00013 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1224
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7849
Epoch 00014 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1209
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7822
Epoch 00015 | Train Forward Time(s) 0.0694 | Backward Time(s) 0.1217
Train Accuracy: 0.8750 | Train Loss: 0.8685 | Validation Accuracy: 0.9643 | Validation loss: 0.7808
Epoch 00016 | Train Forward Time(s) 0.0690 | Backward Time(s) 0.1226
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7800
Epoch 00017 | Train Forward Time(s) 0.0712 | Backward Time(s) 0.1218
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7797
Epoch 00018 | Train Forward Time(s) 0.0695 | Backward Time(s) 0.1217
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7796
Epoch 00019 | Train Forward Time(s) 0.0694 | Backward Time(s) 0.1222
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00020 | Train Forward Time(s) 0.0718 | Backward Time(s) 0.1224
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00021 | Train Forward Time(s) 0.0708 | Backward Time(s) 0.1217
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00022 | Train Forward Time(s) 0.0710 | Backward Time(s) 0.1216
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00023 | Train Forward Time(s) 0.0711 | Backward Time(s) 0.1220
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00024 | Train Forward Time(s) 0.0713 | Backward Time(s) 0.1229
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
max memory allocated 9701254656
Test Accuracy: 0.9444 | Test loss: 0.7994

Mean forward time: 0.070180
Mean backward time: 0.121936
^^^9.034997^^^0.192116
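The final `^^^` line appears to be a machine-readable summary of the run. This is an inference from the numbers, not something the script source confirms: 9.034997 matches the peak memory above converted to GiB, and 0.192116 matches the mean forward time plus the mean backward time. A quick arithmetic check:

```python
# Values copied from the log above.
max_mem_bytes = 9701254656
mean_fwd, mean_bwd = 0.070180, 0.121936

mem_gib = max_mem_bytes / 2**30    # bytes -> GiB
step_time = mean_fwd + mean_bwd    # mean seconds per training step
print(f"^^^{mem_gib:.6f}^^^{step_time:.6f}")
```

Under that reading, the run peaks at roughly 9 GiB of memory and spends about 0.19 s per epoch in forward plus backward.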

JoelMathewC commented 1 year ago

We've identified quite a few changes that need to be made in the codegen portion to support RGCN. We paused when faced with the question of whether moving this support into the compiler would bring any benefit at all: technically, a relational graph can be split into homogeneous subgraphs (one per relation) that Seastar can process.
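The split-into-homogeneous-subgraphs alternative can be sketched in plain PyTorch as below. Each relation's edges form an ordinary homogeneous graph, a standard mean aggregation runs on each one with that relation's weight, and the per-relation results are summed together with the self-loop term. All function names are hypothetical, and the mean aggregation stands in for whatever homogeneous kernel Seastar would generate; this assumes the same RGCN update (normalization c_{i,r} = |N_i^r|, no basis decomposition).

```python
import torch


def split_by_relation(src, dst, etype, num_rels):
    """Split a relational edge list into one homogeneous edge list per relation."""
    return [(src[etype == r], dst[etype == r]) for r in range(num_rels)]


def mean_aggregate(x, src, dst, num_nodes):
    """Homogeneous mean aggregation: out[i] = mean of x[j] over edges j -> i."""
    summed = torch.zeros(num_nodes, x.shape[1], dtype=x.dtype)
    summed.index_add_(0, dst, x[src])
    deg = torch.zeros(num_nodes, dtype=x.dtype)
    deg.index_add_(0, dst, torch.ones(dst.shape[0], dtype=x.dtype))
    # Vertices with no incoming edges of this relation contribute zero.
    return summed / deg.clamp(min=1).unsqueeze(1)


def rgcn_via_subgraphs(h, src, dst, etype, weight, self_weight):
    num_nodes = h.shape[0]
    out = h @ self_weight  # self-loop term W_0 h_i
    subgraphs = split_by_relation(src, dst, etype, weight.shape[0])
    for r, (s, d) in enumerate(subgraphs):
        # Relation r is just a homogeneous graph with its own weight W_r.
        out = out + mean_aggregate(h @ weight[r], s, d, num_nodes)
    return out
```

The trade-off is one kernel launch per relation; on datasets with many relations (each with few edges) that overhead may be an argument for native relational support in the compiler.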

Note: torch.bmm can be used to multiply the input feature vectors by multiple weight matrices (one per relation).
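One way `torch.bmm` could be applied here is the per-edge formulation sketched below. Each edge's relation ID selects its weight matrix, and a single batched matmul transforms every source feature at once. The function name `per_edge_messages` is hypothetical, and the layout (weights stored as `[R, in_dim, out_dim]`) is an assumption.

```python
import torch


def per_edge_messages(h, src, etype, weight):
    """Compute h[src[e]] @ W_{etype[e]} for every edge e with one bmm call.

    h: [N, in_dim], src/etype: [E], weight: [R, in_dim, out_dim]
    """
    w_e = weight[etype]               # [E, in_dim, out_dim]: one matrix per edge
    x_e = h[src].unsqueeze(1)         # [E, 1, in_dim]: each source row as a 1 x in_dim matrix
    return torch.bmm(x_e, w_e).squeeze(1)  # [E, out_dim]
```

Gathering `weight[etype]` materializes E copies of an in_dim × out_dim matrix, so for large graphs it may be cheaper to group edges by relation and do one plain matmul per relation instead.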