Open boliqq07 opened 3 years ago
Scatter is a non-deterministic operation by design since it makes use of atomic operations in which the order of aggregation is non-deterministic, leading to minor numerical differences. As an alternative, you can make use of the segment_csr operation of torch_scatter, see here.
For message passing layers, deterministic aggregation is only guaranteed when using SparseTensor.
In the end, I wouldn't worry too much about it. In a deep learning scenario, such numerical instabilities should only be noticeable on really small datasets. Although it is correct that exact reproducibility is no longer guaranteed when using non-deterministic operations, we can only enforce reproducibility for a single permutation (which does not exist in the context of graphs).
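For concreteness, a small hedged sketch of the two aggregation paths (shapes and values are made up for illustration, and it assumes torch_scatter is installed; segment_csr additionally requires the messages to already be grouped per destination node, which is what the CSR pointers encode):

import torch
from torch_scatter import scatter, segment_csr

src = torch.randn(6, 16)                    # 6 messages with 16 features each
index = torch.tensor([0, 0, 1, 1, 1, 2])    # destination node of each message (sorted)
indptr = torch.tensor([0, 2, 5, 6])         # CSR pointers encoding the same grouping

out_scatter = scatter(src, index, dim=0, reduce='sum')  # atomics on GPU: summation order may vary
out_segment = segment_csr(src, indptr, reduce='sum')    # fixed reduction order: deterministic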
def __lift__(self, src, edge_index, dim):
    if isinstance(edge_index, Tensor):
        index = edge_index[dim]
        return src.index_select(self.node_dim, index)
    elif isinstance(edge_index, SparseTensor):
        if dim == 1:
            rowptr = edge_index.storage.rowptr()
            rowptr = expand_left(rowptr, dim=self.node_dim, dims=src.dim())
            return gather_csr(src, rowptr)
        elif dim == 0:
            col = edge_index.storage.col()
            return src.index_select(self.node_dim, col)
    raise ValueError
(This code is in MessagePassing.)
I tried using SparseTensor, segment_csr and gather_csr, and the aggregation is deterministic. But for a general network, the problem of non-reproducibility still exists.
Finally, I found the problem in the index_select function of torch: index_select can be a non-deterministic operation too. I tried plain [ ] indexing instead of index_select, and it works, though with reduced versatility.
Replace
# return src.index_select(self.node_dim, col)
with
return src[col]
and everything is OK.
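As an illustration only, a tiny sketch of that swap with made-up shapes; it assumes node_dim is 0, which is why plain bracket indexing is a drop-in replacement here:

import torch

src = torch.randn(5, 8)                 # node features: 5 nodes, 8 channels
col = torch.tensor([0, 2, 2, 4, 1])     # source node of each edge

a = src.index_select(0, col)  # the original call (with node_dim == 0)
b = src[col]                  # the suggested replacement: advanced indexing along dim 0
assert torch.equal(a, b)      # identical forward results; only the backward kernels may differ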
Running torch.use_deterministic_algorithms(True) should fix that as well, I guess :)
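A sketch of a typical reproducibility setup along those lines (the CUBLAS_WORKSPACE_CONFIG variable is only needed on newer CUDA versions, and with this flag some operations will raise an error instead of silently running non-deterministically):

import os, random
import numpy as np
import torch

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # required by some deterministic CUDA kernels
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.use_deterministic_algorithms(True)  # error out on known non-deterministic ops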
@rusty1s Hello, you said scatter could result in indeterminacy and thus minor numerical differences. But scatter is intrinsically permutation invariant, as you said, so why would there be any difference if it does not depend on the order in which the elements are scattered? What I mean is: 1 + 2 + 3 and 2 + 1 + 3 give the same result and the same gradient, right? Then why do differences occur?
As far as I know, the CUDA version of scatter uses several parallel "sub-tasks" to implement this operation, is that the reason?
Yes, this is due to how floating-point precision works. In case the ordering of operations is not deterministic internally, you may get slightly different outputs, e.g., (1 + 2) + 3 may be different from 1 + (2 + 3).
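A quick float32 illustration of that non-associativity (the particular values are just an example chosen to make the rounding visible):

import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1e-3, dtype=torch.float32)

print((a + b) + c)  # tensor(0.0010)
print(a + (b + c))  # tensor(0.), because b + c rounds back to -1e8 in float32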
Then I understand, and thank you for such an immediate reply. Very helpful.
Hi, could you please share how to use SparseTensor? I am really struggling with it; for example, some functions like negative_sampling only support Tensor rather than SparseTensor. Where do you convert a normal tensor to a SparseTensor? Thanks :)
Did you take a look at https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html? You can convert between the two via:
row, col, edge_attr = adj_t.t().coo()
edge_index = torch.stack([row, col], dim=0)
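If it helps, the reverse direction (building a SparseTensor from edge_index) can be sketched roughly like this; num_nodes is a placeholder for the actual number of nodes in your graph:

from torch_sparse import SparseTensor

adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr,
                     sparse_sizes=(num_nodes, num_nodes)).t()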
Hi Matthias, thanks for the quick reply. I think in my case the big difference between rounds mostly comes from the training-data batch generation. Do you have any suggestion on how to make batch generation (using NeighborLoader, with all seeds set as suggested above) deterministic?
I think NeighborLoader should be deterministic if you set a manual seed. Which pyg-lib/torch-sparse version are you using?
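As a sketch of what setting a manual seed looks like here (data, the fan-out and the batch size are placeholder values):

import torch
from torch_geometric.loader import NeighborLoader

torch.manual_seed(0)  # seed before the loader is created and iterated

loader = NeighborLoader(
    data,                    # your graph data object
    num_neighbors=[10, 10],  # fan-out per hop, placeholder values
    batch_size=128,
    shuffle=True,
)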
0.6.13 for torch_sparse. I tried to fix the seed, but I got batches with different numbers of edge pairs and even different numbers of root nodes across rounds. Btw, my graph is heterogeneous, not sure if that has an impact.
Deterministic neighborhood sampling is available from torch-sparse 0.6.14 onwards, see here.
I think I found the problem. Just setting device="cpu" is not enough to disable CUDA. I created a new environment with the CPU-only versions of torch and pyg, and it is reproducible now. Thanks for the help :)
Hello!
I'm facing a similar issue and wonder if / how to use SparseTensors. What should be transformed to a sparse tensor to get a deterministic output?
I use graph = T.Compose([T.ToSparseTensor()])(graph) to get an adjacency matrix of type <class 'torch_sparse.tensor.SparseTensor'>.
My graph.x remains of type <class 'torch.Tensor'>, but it does not contain zeros, and changing it to a SparseTensor leads to errors (about it not having a strided layout). I found here that a dense tensor should work.
Idk if it helps, but a strange thing is that the behaviors of my different runs only start to diverge after some iterations, as seen in the picture (I checked with a precision of 10^-16, and the first 12-15 loss values are exactly the same).
Yes, graph.x should be a dense tensor.
It's hard for me to say what might have gone wrong here. There might be other sources of non-determinism (e.g., differently sampled mini-batches from step 70 onwards).
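For reference, a minimal sketch of that setup (the Planetoid/Cora dataset is only a placeholder choice): only the adjacency becomes a SparseTensor, while the node features stay dense:

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora',
                    transform=T.ToSparseTensor())  # stores the adjacency as data.adj_t
data = dataset[0]

print(type(data.adj_t))  # torch_sparse SparseTensor, enabling deterministic aggregation
print(type(data.x))      # plain dense torch.Tensor, as recommended above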
❓ Questions & Help
I am trying to reproduce my work, but I find that the scatter in torch_scatter (CUDA) is unstable, even with a fixed random seed. Since scatter is used in the MessagePassing class, I thought it is worth paying attention to. Or did I make a mistake or neglect something?
The following are my test results. I'd appreciate it if anyone could help or has an idea.
file.md