Closed: bzbzbz22 closed this issue 7 months ago
In addition, the problem also occurs when we use some functions in DGL, such as dgl.function.u_mul_e(). The DGL we use is version 0.7.2, and when I test with version 0.8.2 the problem also occurs.
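(For context, a minimal sketch of the kind of message-passing call meant above; the toy graph and the feature/field names 'h', 'w', 'm', 'h_new' are made up for illustration and are not code from this report.)

import torch
import dgl
import dgl.function as fn

# toy graph with node features 'h' and scalar edge weights 'w'
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['h'] = torch.randn(3, 4)
g.edata['w'] = torch.rand(3, 1)

# multiply each source-node feature by its edge weight, then sum the
# incoming messages; on GPU this kind of reduction is where small
# run-to-run differences can show up
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
print(g.ndata['h_new'].shape)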
Could you please help me solve this problem? It is very important for me. @Rhett-Ying @mufeili
Please share more details of your config for dgl.nn.GraphConv. How did you initialize it, and what are the arguments when calling forward?
Thank you for your reply.
The GCN we built using DGL is as follows:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

class GCN_dgl(nn.Module):
    def __init__(self,
                 g,
                 nfeat,
                 nlayers,
                 nhid,
                 nclass,
                 dropout):
        super(GCN_dgl, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.layers.append(GraphConv(nfeat, nhid))
        for i in range(nlayers - 1):
            self.layers.append(GraphConv(nhid, nhid))
        self.layers.append(GraphConv(nhid, nclass))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, x):
        h = x
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        # return h
        return F.softmax(h, 1)
We instantiate this GCN like this:
model = GCN_dgl(g=g,
                nfeat=features.shape[1],
                nhid=args.hidden_dim,
                nclass=class_num[args.dataset],
                nlayers=args.num_layers,
                dropout=args.dropout)
We use this GCN like this:
logits = model(g, features)
The code and data are here. You can download them and run python main.py --dataset cora directly.
Thank you very much.
What if you disable dropout?
We built the GCN without dropout:
class GCN_dgl(nn.Module):
    def __init__(self,
                 nfeat,
                 nlayers,
                 nhid,
                 nclass):
        super(GCN_dgl, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GraphConv(nfeat, nhid))
        for i in range(nlayers - 1):
            self.layers.append(GraphConv(nhid, nhid))
        self.layers.append(GraphConv(nhid, nclass))

    def forward(self, g, x):
        h = x
        for i, layer in enumerate(self.layers):
            h = layer(g, h)
        return F.softmax(h, 1)
The issue still occurs:
Repeat 0:
Run 0, test_acc: 0.5676
Run 1, test_acc: 0.2973
Run 2, test_acc: 0.4595
Run 3, test_acc: 0.5405
Run 4, test_acc: 0.4595
Repeat 1:
Run 0, test_acc: 0.5405
Run 1, test_acc: 0.4054
Run 2, test_acc: 0.4595
Run 3, test_acc: 0.5405
Run 4, test_acc: 0.4595
Did you run this on CPU or GPU? If on GPU, could you try on CPU and see if the issue still occurs?
I ran it on CPU, and the issue still occurs.
Repeat 0:
Run 0, test_acc: 0.5676
Run 1, test_acc: 0.2973
Run 2, test_acc: 0.4595
Run 3, test_acc: 0.5405
Run 4, test_acc: 0.4595
Repeat 1:
Run 0, test_acc: 0.5405
Run 1, test_acc: 0.4054
Run 2, test_acc: 0.4595
Run 3, test_acc: 0.5405
Run 4, test_acc: 0.4595
Hello, can this problem be solved? We think reproducibility of experimental results is very important, especially for scientific research. We hope this problem gets your attention. Thanks. @BarclayII @Rhett-Ying
The non-deterministic result is caused by weight=True when instantiating GraphConv. This argument is True by default, in which case the weight is created by the code below:
https://github.com/dmlc/dgl/blob/0f0e7c7fad6ec1399a67881ce9a7e60d14cfe1a0/python/dgl/nn/pytorch/conv/graphconv.py#L283-L284
Then the weight is re-initialized by https://github.com/dmlc/dgl/blob/0f0e7c7fad6ec1399a67881ce9a7e60d14cfe1a0/python/dgl/nn/pytorch/conv/graphconv.py#L312-L313
These calls are not deterministic even though you have manually set the seeds, so self.weight differs in each instantiation. You could print out model.state_dict() to verify this.
So, in order to obtain a deterministic result, please specify weight=False when instantiating GraphConv and pass a deterministic weight when calling forward().
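A minimal sketch of that workaround (the toy graph and feature sizes are made up for illustration): instantiate GraphConv with weight=False and pass an externally managed weight tensor to forward().

import torch
import torch.nn as nn
import dgl
from dgl.nn.pytorch import GraphConv

torch.manual_seed(0)

# toy graph and features, just for illustration
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
feat = torch.randn(3, 5)

# create the layer without its own internal weight ...
conv = GraphConv(5, 4, weight=False)

# ... and manage the weight ourselves, so its initialization happens
# exactly once and under our control
weight = nn.Parameter(torch.empty(5, 4))
nn.init.xavier_uniform_(weight)

# pass the external weight at call time
h = conv(g, feat, weight=weight)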
This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you
This is the initialization of the weight; with a fixed random seed, it should also be fixed. And if weight is set to False (and no weight is passed to forward()), no linear transform is applied, so the input and output feature dimensions must be consistent.
I checked the code in the repo and found the setup is not correct. As @Rhett-Ying mentioned, the model weights will be re-initialized each time you create a model object. If you want deterministic behavior, you could either reset all the seeds right before the model is created each time, or initialize the model once, checkpoint the weights, and reload them for each repeated run.
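A minimal sketch of the first option (the helper name and the seed value are placeholders, not from this thread): re-seed every RNG immediately before each model construction so all repeats start from identical parameters.

import numpy as np
import torch
import dgl

def reset_seeds(seed=0):
    # re-seed every RNG so that the next model construction draws the
    # same initial weights as every other repeat
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    dgl.seed(seed)

# usage inside the repeat loop (model/training code elided):
# for repeat in range(num_repeats):
#     reset_seeds(0)
#     model = GCN_dgl(...)   # construct the model only after re-seeding
#     ...train and evaluate...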
This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you
When using dgl.nn.pytorch.GraphConv, I meet the same problem as @bzbzbz22. Actually, I find that the problem may be caused by the number of edges in the graph: when I reduce the number of edges to 10,000, the obtained results are deterministic.
When I use the CPU, the results are fixed each time, while when I use the GPU, the results are uncertain. In my opinion, this is because some addition and multiplication operations in DGL are performed in parallel on the GPU, so their order differs between runs, and since float32 and float64 precision is limited, the results end up inconsistent.
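As a side note, a minimal standalone demonstration of that floating-point point (not from this thread): summing the same float32 values in a different order can change the result slightly, which is exactly what a parallel GPU reduction may do.

import torch

torch.manual_seed(0)
x = torch.rand(1_000_000, dtype=torch.float32)

# same numbers, different summation order
s_forward = x.sum()
s_reverse = x.flip(0).sum()

# the two sums may differ in the last few bits because floating-point
# addition is not associative
print(s_forward.item(), s_reverse.item())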
Reopened as more people are getting similar problems. @zxqaaaaa @jinfy Could you provide a runnable script for us to reproduce?
I also ran the code on GPU. When I reduce the number of edges, the obtained results are deterministic. Maybe you could try it.
We released a demo that performs a link prediction task using DGL. The code is here.
Has this problem been solved? The code is here.
Hi @jinfy, I've run your scripts. Confirmed that on CPU, the results are deterministic.
CPU training log:
Repeat 0:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cpu
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8322153590440466
Run 400, test_acc: 0.8201617214348285
Run 500, test_acc: 0.8146011994339750
Run 600, test_acc: 0.8085635992003772
Run 700, test_acc: 0.8060847689854226
Run 800, test_acc: 0.8040857123604591
Run 900, test_acc: 0.8022870106241997
Repeat 1:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cpu
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8322153590440466
Run 400, test_acc: 0.8201617214348285
Run 500, test_acc: 0.8146011994339750
Run 600, test_acc: 0.8085635992003772
Run 700, test_acc: 0.8060847689854226
Run 800, test_acc: 0.8040857123604591
Run 900, test_acc: 0.8022870106241997
Repeat 2:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cpu
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8322153590440466
Run 400, test_acc: 0.8201617214348285
Run 500, test_acc: 0.8146011994339750
Run 600, test_acc: 0.8085635992003772
Run 700, test_acc: 0.8060847689854226
Run 800, test_acc: 0.8040857123604591
Run 900, test_acc: 0.8022870106241997
Repeat 3:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cpu
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8322153590440466
Run 400, test_acc: 0.8201617214348285
Run 500, test_acc: 0.8146011994339750
Run 600, test_acc: 0.8085635992003772
Run 700, test_acc: 0.8060847689854226
Run 800, test_acc: 0.8040857123604591
Run 900, test_acc: 0.8022870106241997
Repeat 4:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cpu
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8322153590440466
Run 400, test_acc: 0.8201617214348285
Run 500, test_acc: 0.8146011994339750
Run 600, test_acc: 0.8085635992003772
Run 700, test_acc: 0.8060847689854226
Run 800, test_acc: 0.8040857123604591
Run 900, test_acc: 0.8022870106241997
However, on GPU, the training results begin to diverge after 100 epochs.
GPU logs:
Repeat 0:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cuda
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8563001729520900
Run 300, test_acc: 0.8313806967498484
Run 400, test_acc: 0.8197619101098359
Run 500, test_acc: 0.8151178095730104
Run 600, test_acc: 0.8088735652837987
Run 700, test_acc: 0.8056139799195885
Run 800, test_acc: 0.8034837492419307
Run 900, test_acc: 0.8016185620269086
Repeat 1:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cuda
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8562992744996744
Run 300, test_acc: 0.8320060196311853
Run 400, test_acc: 0.8194465533119202
Run 500, test_acc: 0.8145661597897622
Run 600, test_acc: 0.8093209945868243
Run 700, test_acc: 0.8061961770849712
Run 800, test_acc: 0.8041504009343906
Run 900, test_acc: 0.8027101817120011
Repeat 2:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cuda
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8563028683093371
Run 300, test_acc: 0.8300420026504346
Run 400, test_acc: 0.8203746546573526
Run 500, test_acc: 0.8151169111205948
Run 600, test_acc: 0.8088331349250917
Run 700, test_acc: 0.8054774151524000
Run 800, test_acc: 0.8037703555625436
Run 900, test_acc: 0.8019249343006670
Repeat 3:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cuda
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8561375530648458
Run 300, test_acc: 0.8306888883897486
Run 400, test_acc: 0.8206181352620112
Run 500, test_acc: 0.8149291345657106
Run 600, test_acc: 0.8090658341007614
Run 700, test_acc: 0.8058987893353700
Run 800, test_acc: 0.8048637721524674
Run 900, test_acc: 0.8025879921834638
Repeat 4:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Test device: cuda
Run 0, test_acc: 0.6514107949057748
Run 100, test_acc: 0.8621059724624336
Run 200, test_acc: 0.8562992744996744
Run 300, test_acc: 0.8308200624424430
Run 400, test_acc: 0.8187574403090675
Run 500, test_acc: 0.8150117921879563
Run 600, test_acc: 0.8091422025560970
Run 700, test_acc: 0.8061108241054783
Run 800, test_acc: 0.8036077356752993
Run 900, test_acc: 0.8021765009770671
Trying to find out the root cause. My environment: DGL 0.9.1, CUDA 11.3.
I stored the weights before training my model as you suggested and it worked! Now the model is deterministic. Below is my code snippet:
# checkpoint the model weights
torch.save(base_model.state_dict(), 'base_model.pth')
# load the model weights
base_model.load_state_dict(torch.load('base_model.pth'))
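For completeness, a runnable sketch of the same pattern end to end (the Linear layer here is only a stand-in for the actual DGL-based GCN, and the repeat count is arbitrary): save the freshly initialized weights once, then reload them at the start of every repeated run.

import torch
import torch.nn as nn

# stand-in model; in this thread it would be the DGL-based GCN
base_model = nn.Linear(8, 4)

# checkpoint the freshly initialized weights once ...
torch.save(base_model.state_dict(), 'base_model.pth')

# ... and restore them at the start of every repeat so all runs share
# the same starting parameters
for repeat in range(3):
    base_model.load_state_dict(torch.load('base_model.pth'))
    # training and evaluation for this repeat would go here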
Root cause confirmed in #7241. The underlying cuSPARSE algorithm DGL currently uses is not deterministic. It will require some fixes in the C++ library. Closing this issue; let's move future discussion there.
🐛 Bug
In our experiment, we split the training/validation/test sets of a graph dataset (taking Cora as an example) many times. Each split was trained in one run and the classification accuracy was calculated. We found that when we use dgl.nn.pytorch.GraphConv to construct the GCN, the results of repeated experiments are inconsistent even though all random seeds (numpy, torch, dgl) are fixed. The experiment code and data are here.
Running python main.py --use_dgl --dataset cora and python main.py --use_dgl --dataset texas shows that, for both datasets (Cora and Texas), the results of two repeats are inconsistent.
We also built another GCN that does not use dgl.nn.pytorch.GraphConv. Running python main.py --dataset cora and python main.py --dataset texas shows that, for the GCN built without dgl.nn.pytorch.GraphConv, the results of the two repeats are consistent. The lower accuracy of this model is due to the lack of normalization and other operations, which does not affect whether deterministic results are produced.