rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License
1.54k stars, 179 forks

why not use 'torch.ops.torch_scatter.scatter_sum' instead of 'aten::scatter_add_'? #405

Closed JYXL1 closed 3 months ago

JYXL1 commented 10 months ago

The implementation of scatter_sum in scatter.py uses aten::scatter_add_ rather than the library's own implementation, and the same goes for scatter_mean. Why?

rusty1s commented 10 months ago

We benchmarked this and somehow, scatter_add is faster than its scatter_reduce(reduce="sum") counterpart.
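
For readers new to the operation: the kernels being compared are meant to compute the same sum-scatter, so the question is only about speed. The following is a minimal pure-Python sketch of those semantics along dim 0, assuming a 1-D index that is applied to whole rows. It is not the library's code, and `scatter_sum_reference` is a hypothetical name used only for illustration.

```python
def scatter_sum_reference(src, index, dim_size):
    """Add each row src[i] into out[index[i]] (sum-scatter along dim 0)."""
    width = len(src[0])
    out = [[0.0] * width for _ in range(dim_size)]
    for row_values, target in zip(src, index):
        for j, value in enumerate(row_values):
            out[target][j] += value
    return out


# Rows 0 and 2 share target 0, so they are summed together.
src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
index = [0, 1, 0]
result = scatter_sum_reference(src, index, dim_size=2)
print(result)  # [[6.0, 8.0], [3.0, 4.0]]
```

Any faster kernel has to reproduce this result, which is why the thread can compare implementations purely on timing.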

JYXL1 commented 10 months ago

I see. I tested it on a V100-32G and reached the same conclusion. However, when I tested it on CPU, the results were confusing. I tried two different CPUs and got the following results.

With Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz:

DIMACS10/citationCiteseer (avg row length: 8.62):
                  1         16         32         64
SCA2_ROW    0.00787    0.49399    1.72276    4.42197
SCA2_COL    0.01393    0.80683    2.36117    5.42819
SCA3_ROW    0.02172    0.61463    1.59856    3.90823
SCA3_COL    0.03550    0.84841    2.05098    4.41715

SNAP/web-Stanford (avg row length: 8.20):
                  1         16         32         64
SCA2_ROW    0.00751    0.49775    1.72447    4.44219
SCA2_COL    0.01589    0.81978    2.38535    5.51415
SCA3_ROW    0.02170    0.62473    1.62725    3.91508
SCA3_COL    0.03511    0.84524    2.05468    4.44080

Janna/StocF-1465 (avg row length: 14.34):
                  1         16         32         64
SCA2_ROW    0.08433    4.71143   15.44207   39.04459
SCA2_COL    0.24425   11.30899   28.03308   60.25658
SCA3_ROW    0.21467    5.67035   14.47458   35.05938
SCA3_COL    0.41748   10.68510   22.14663   45.80257

With Intel(R) Xeon(R) Gold 5318Y CPU @ 2.10GHz:

DIMACS10/citationCiteseer (avg row length: 8.62):
                  1         16         32         64
SCA2_ROW    0.03260    0.66582    2.01713    5.89464
SCA2_COL    0.04508    1.27152    3.63535    8.23443
SCA3_ROW    0.03065    0.57203    1.50349    3.99253
SCA3_COL    0.04227    0.90501    2.35196    5.22992

SNAP/web-Stanford (avg row length: 8.20):
                  1         16         32         64
SCA2_ROW    0.03212    0.66318    2.01222    6.09568
SCA2_COL    0.04535    1.19622    3.51602    8.35593
SCA3_ROW    0.03066    0.56944    1.47491    4.04390
SCA3_COL    0.04281    0.86653    2.31969    5.34826

Janna/StocF-1465 (avg row length: 14.34):
                  1         16         32         64
SCA2_ROW    0.29553    6.02275   17.75847   49.62980
SCA2_COL    0.47834   21.06510   39.97363   78.55523
SCA3_ROW    0.32140    5.17531   13.06473   33.30777
SCA3_COL    0.44508   14.37845   28.00824   55.28940

where SCA2_ROW, SCA2_COL, SCA3_ROW, and SCA3_COL are defined as follows:

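    import torch
    from torch_scatter import scatter  # importing torch_scatter also registers its ops

    # `row`, `row2`, `dim_size`, and `args` are not defined in this comment;
    # they are presumably set up by the surrounding benchmark script.

    # SCA2: the public `scatter` wrapper from torch_scatter.
    def sca2_row(x):
        return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)

    def sca2_col(x):
        return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)

    # SCA3: the extension's own scatter_sum op, called directly.
    def sca3_row(x):
        return torch.ops.torch_scatter.scatter_sum(x, row, 0, None, dim_size)

    def sca3_col(x):
        return torch.ops.torch_scatter.scatter_sum(x, row2, 0, None, dim_size)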

My test version of torch_scatter is 2.0.9. Based on these results, I guess some CPUs could benefit a lot from your own implementation, and the benefit seems to become more significant as the size grows.
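
To make the trend in the tables above easier to read, the timings can be turned into ratios. The sketch below copies the Janna/StocF-1465 ROW numbers for both CPUs and divides SCA2 by SCA3, so a value above 1 means the direct `torch.ops.torch_scatter.scatter_sum` call was faster. The dictionary keys are the table's column headers, whose meaning and units the comment does not state, and `speedup` is a hypothetical helper name.

```python
# Timings copied verbatim from the Janna/StocF-1465 tables in this thread.
# Keys are the table's column headers (their meaning is not stated in the thread).
timings = {
    "Gold 6248R": {
        "SCA2_ROW": {1: 0.08433, 16: 4.71143, 32: 15.44207, 64: 39.04459},
        "SCA3_ROW": {1: 0.21467, 16: 5.67035, 32: 14.47458, 64: 35.05938},
    },
    "Gold 5318Y": {
        "SCA2_ROW": {1: 0.29553, 16: 6.02275, 32: 17.75847, 64: 49.62980},
        "SCA3_ROW": {1: 0.32140, 16: 5.17531, 32: 13.06473, 64: 33.30777},
    },
}


def speedup(baseline, candidate):
    """Return baseline/candidate per column; > 1 means candidate is faster."""
    return {col: baseline[col] / candidate[col] for col in baseline}


ratios = {
    cpu: speedup(t["SCA2_ROW"], t["SCA3_ROW"]) for cpu, t in timings.items()
}

for cpu, per_col in ratios.items():
    line = ", ".join(f"{col}: {r:.2f}x" for col, r in per_col.items())
    print(f"{cpu}: {line}")
```

On these numbers the ratio rises with the column value on both CPUs. It stays below 1 in the first column and reaches about 1.49x in the last column on the Gold 5318Y, which matches the observation that the gap widens as the size grows.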

github-actions[bot] commented 3 months ago

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?