mingfeima opened this issue 2 years ago
Thanks for this detailed issue and roadmap. PyTorch recently released torch.scatter_reduce as well. As such, the long-term goal is to move to the PyTorch implementation of the torch.scatter_reduce routines, and the current optimizations of torch-scatter are probably not future-proof as a result. Can we also benchmark torch.scatter_reduce against torch_scatter to see if there is already a performance gain by simply switching the implementation?
Ok, I see. Then it is better to optimize scatter_reduce in torch. Just checked the code: scatter_add and scatter_reduce share the same kernel in torch, so they have the same performance issues. Will have it fixed.
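For reference, a minimal micro-benchmark along these lines could look like the sketch below (not from the thread; the shapes mirror the Reddit profile further down, and torch_scatter must be installed):

```python
# Hedged sketch: compare torch.Tensor.scatter_reduce_ against torch_scatter.scatter
# for the broadcasted-index "sum" case that dominates the PyG profiles below.
import time
import torch
from torch_scatter import scatter

E, N, C = 477_263, 135_361, 256          # edges, nodes, feature dim (Reddit-like shapes)
src = torch.randn(E, C)
row = torch.randint(0, N, (E,))
index = row.view(-1, 1).expand(-1, C)    # broadcasted index, strides (1, 0)

def bench(fn, iters=10):
    fn()                                  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

t_aten = bench(lambda: torch.zeros(N, C).scatter_reduce_(0, index, src, reduce="sum"))
t_ext = bench(lambda: scatter(src, row, dim=0, dim_size=N, reduce="sum"))
print(f"torch.scatter_reduce_: {t_aten:.3f}s  torch_scatter.scatter: {t_ext:.3f}s")
```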
hey @mingfeima this issue is great, a lot of detail.
NeighborLoader parallelization: the current impl is sequential (probably to avoid oversubscription with multiple workers on the data loader). Unlike GPU runs, asynchronously running the data loading thread and the computation thread does not always make sense. On some occasions, running the data loading step and the computation step sequentially, while making each torch operator parallel via OpenMP (i.e. intra-op parallelism), makes more sense.
I'm a complete newbie to this, so my question is to learn, not to suggest something. Can you explain what you're intending to change here? It sounds like you want to keep it sequential but make the actual sampling itself parallel?
Yes, that's the idea! The data loader from pytorch is more suitable for GPU (by setting num_workers=N, it will launch data loading threads asynchronously with the main computation thread). On CPU, it is probably better to run data loading and computation sequentially while parallelizing the sampler inside the data loader with OpenMP threads.
That makes complete sense. Let me know if you'd like any help (I'd likely not be quick though 😅)
@mingfeima I assume the benchmarks have been run with num_workers=0? This explains why this is a bottleneck. Can you share some insights on how an OpenMP implementation of sampling behaves in relation to num_workers>0? Is it expected that this will potentially slow down the code compared to a single-threaded implementation that utilizes parallelism solely on the worker level?
The current benchmark profiling result uses the default setting. Some scripts, for example to_hetero_mag, explicitly set num_workers; if not, the default setting of 4 will be used.

DataLoader time in the benchmark profile result actually comprises two parts. The second part takes more time, so it is still possible to improve it with a single worker + parallel OpenMP. If we use num_workers>0, we need to make sure OpenMP in each worker has the correct settings (omp_num_threads and core affinity binding) to avoid over-subscription.

Actually the data loader optimization is a rather complex issue, perhaps more complex than optimizing the kernels :( since it is more of a tuning job to reach the most balanced point between the workload payload (memory footprint, computation complexity, etc.) and the hardware capacity (IO, memory bandwidth, ALU flops).

Usually we do not do data loading optimizations, since the real case in deployment would probably be even more complex (some vendors have mechanisms like prefetching and batching to improve overall user experience and efficiency). But the thing is, DGL has done some optimizations here, so we need to do at least something similar, otherwise out-of-the-box performance on PyG would look bad.

Anyway, we will make sure that OpenMP has the correct settings for either num_workers=0 or num_workers=N, and also that each sampler can be properly parallelized. num_workers=0 benefits the pre-processing more and num_workers=N benefits the IO more. We will let users decide which way to go (maybe we can give a BKM or some simple guideline).
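To make the two configurations concrete, a rough sketch (not an official recipe; the core counts, fan-outs, and the toy graph are illustrative) might look like:

```python
# Hedged sketch of the two settings discussed above: give OpenMP all cores when
# num_workers=0, or shrink the OpenMP pool and pin cores when num_workers>0, so the
# sampler workers and the compute threads do not oversubscribe the CPU.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

data = Data(x=torch.randn(1000, 64),
            edge_index=torch.randint(0, 1000, (2, 5000)))

# Option A: num_workers=0 -- sampling runs in the main process, OpenMP gets all cores.
torch.set_num_threads(20)                                  # e.g. one 20-core socket
loader_a = NeighborLoader(data, num_neighbors=[25, 10], batch_size=128, num_workers=0)

# Option B: num_workers=N -- reserve cores for the workers and shrink the OpenMP pool;
# core affinity would additionally be set outside Python, e.g.
#   OMP_NUM_THREADS=16 numactl -C 0-15 python train.py
torch.set_num_threads(16)
loader_b = NeighborLoader(data, num_neighbors=[25, 10], batch_size=128,
                          num_workers=4, persistent_workers=True)

for batch in loader_a:                                     # iterate as usual
    pass
```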
Updates on scatter_add optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/82703
Depending on the type of edge_index, message passing will choose different paths: a) scatter_add for a dense tensor; b) spmm for a SparseTensor. The principal factor here is the memory format: in the 1st case the memory format of edge_index is COO, and in the 2nd case it is CSR.

scatter_add is used to aggregate info row-wise, which means the index tensor is extended (broadcasted along the last dimension, so each row repeats a single value).

A typical input shape for the Reddit dataset looks like:
self.sizes(): [135361, 256]; self.strides(): [256, 1]
index.sizes(): [477263, 256]; index.strides(): [1, 0]
src.sizes(): [477263, 256]; src.strides(): [256, 1]
So we pick rows from 477K indices and update the destination rows in self. Ideally we want to parallelize on the outer dimension (477K or 135K), but the scatter pattern means writes conflict between threads. The current ATen kernel chooses to parallelize on the inner dimension of 256, which is not performant for the PyG usage: a) per-thread memory access is non-contiguous; b) it cannot be vectorized.
There exist a couple of algorithms to fix the write conflict, such as: a) sorting; b) segment mutex; c) atomics; d) shared buffer, ... I chose a) sorting based on the input shape range, which should be the most performant. If anyone comes up with a better idea, please let me know :)
So, the approach is to: a) convert index to CSR format, using a parallel radix sort; b) do an spmm-style reduction.

I used SAGE+reddit from https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py (the reason I picked this one is that it uses NeighborLoader, which means I don't have to take data loader optimization into account here).
For inference on a single-socket ICX Xeon (20 cores @ 2.50GHz), end-to-end inference time reduced from 77.135s to 44.622s.
Attached is part of the profiling logs; as we can see, scatter_add reduced from 37.797s to 6.454s.
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::scatter_add_ 49.00% 37.797s 49.00% 37.797s 41.445ms 912
aten::index_select 19.74% 15.223s 19.74% 15.227s 6.678ms 2280
aten::linear 0.01% 5.706ms 15.04% 11.602s 12.721ms 912
aten::addmm 6.62% 5.108s 7.92% 6.112s 13.403ms 456
aten::matmul 0.00% 2.339ms 7.10% 5.475s 12.006ms 456
aten::mm 7.09% 5.472s 7.09% 5.472s 12.001ms 456
aten::index 5.89% 4.544s 5.90% 4.549s 9.845ms 462
aten::fill_ 3.59% 2.768s 3.59% 2.768s 2.014ms 1374
aten::zeros 0.01% 7.616ms 3.44% 2.652s 1.936ms 1370
aten::zero_ 0.00% 2.728ms 3.42% 2.636s 1.924ms 1370
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::index_select 32.26% 14.395s 32.27% 14.398s 6.315ms 2280
aten::linear 0.01% 6.329ms 26.19% 11.688s 12.815ms 912
aten::scatter_add_ 14.46% 6.454s 14.46% 6.454s 7.077ms 912
aten::addmm 11.71% 5.223s 13.58% 6.060s 13.289ms 456
aten::matmul 0.01% 2.257ms 12.58% 5.612s 12.307ms 456
aten::mm 12.57% 5.610s 12.57% 5.610s 12.302ms 456
aten::index 9.98% 4.453s 9.99% 4.456s 9.646ms 462
aten::fill_ 5.62% 2.506s 5.62% 2.506s 1.369ms 1830
aten::zeros 0.02% 7.091ms 5.53% 2.466s 1.800ms 1370
aten::zero_ 0.01% 2.886ms 5.50% 2.453s 1.790ms 1370
aten::true_divide_ 0.02% 7.360ms 4.72% 2.106s 4.618ms 456
There are still some TODOs to follow up which will bring additional performance improvement:
- can_use_32bit_index
- the case where index and src are 1d, aka. inner_size is 1
- max, min, mean reduce types
- gather, which will be used for training

Just to understand: Does this mean that we first sort index and then do a segment reduction? In that case it might be good to preserve the information that index is sorted, such that we do not have this overhead in consecutive GNN layers.
Yes, that's the idea. The overhead is not only the sorting; we also have to calculate the row_ptr indices...
So the index should be constant, since it is an attribute of the dataset? If we can cache the sorted index, scatter_add performance could be further improved by roughly 1/3.
As far as I understand, this would refer to a segment_add implementation, correct? Similar to the one present in torch-scatter.segment. Is there also a chance we can optimize scatter_add without relying on sorting?
Yes, the current scatter_add is kind of like sorting + segment_add, and both parts are properly parallelized. Because of the semantics limitation, we cannot skip the sorting, since on the PyTorch side there is no guarantee that index is in ascending order.
As for the optimization techniques for scatter_add, I tried a couple of methods: a) sorting (the currently submitted PR uses this approach); b) a mutex on a block of the write addresses; c) atomics on the innermost dimension; and so on. My experiments show that a) performs best right now, since sorting also helps increase cache locality and we can enable blocking on the nnz dimension (so there is only one write per row and multiple reads from src). For b) we cannot do blocking on writes, so there would be multiple writes as well; c) is only suitable for certain inner_sizes, e.g. src: [135K, 1] and index: [477K, 1], and atomics come at a higher price than a normal FMA.
The overhead is actually not mainly from the sorting itself but from memory allocation; I will switch to the c10 CPU allocator to see if it helps.
Aside from that, we probably have a chance to cache the sorted index so as to save the sorting in consecutive layers. Maybe add an attribute called edge_index_sorted: for the 1st layer we fill it with the sorted index, and in the consecutive layers we can directly use segment_add. (Of course, we also need to make sure that segment_add is fully optimized.)
This is only a rough idea at the moment. My point is that we first clear the performance bottlenecks on the torch side (so we optimize scatter_add as it is, with no upper-level API change) and then seek more optimization opportunities on the pyg/torch-scatter/torch-sparse side, where we can make more aggressive optimizations.
Got it, thanks for clarifying!
Update on spmm optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/83727.

This ports the spmm reduction from torch-sparse to torch; the current PR is only for demonstrating performance gains, and the API definition needs more amendment. For now only sum is added; more will come in the future (max, mean, min), and the algorithm is pretty much the same.
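Semantically, this sum reduction is the usual sparse-adjacency-times-dense-features product; a hedged illustration using public PyTorch sparse CSR APIs on a reasonably recent release (not the new kernel's exact entry point) is:

```python
# Hedged sketch of the spmm "sum" reduction: aggregate neighbor features by multiplying
# a sparse CSR adjacency matrix with the dense feature matrix. The PR moves an
# equivalent reduction kernel from torch-sparse into torch itself.
import torch

crow = torch.tensor([0, 2, 3, 4])            # CSR row pointers: node 0 has 2 neighbors
col = torch.tensor([1, 2, 0, 1])             # column indices (neighbor ids)
val = torch.ones(4)                          # unweighted edges
adj = torch.sparse_csr_tensor(crow, col, val, size=(3, 3))

x = torch.randn(3, 8)                        # node features
out = adj @ x                                # row-wise sum over each node's neighbors
print(out.shape)                             # torch.Size([3, 8])
```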
I selected the benchmark from ./ogb/examples/nodeproppred/products/gnn.py, since originally it spent the majority of its time in torch_sparse::spmm_sum. The spmm got roughly a 5x speedup on my 20-core machine.
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
torch_sparse::spmm_sum 97.09% 56.086s 97.09% 56.088s 6.232s 9
aten::linear 0.00% 85.000us 1.38% 795.485ms 88.387ms 9
aten::matmul 0.00% 57.000us 1.38% 795.260ms 88.362ms 9
aten::mm 1.38% 795.201ms 1.38% 795.203ms 88.356ms 9
aten::relu 0.00% 50.000us 0.76% 440.434ms 73.406ms 6
aten::clamp_min 0.76% 440.384ms 0.76% 440.384ms 73.397ms 6
aten::add_ 0.57% 327.801ms 0.57% 327.801ms 36.422ms 9
aten::log_softmax 0.00% 23.000us 0.10% 55.503ms 18.501ms 3
aten::_log_softmax 0.10% 55.480ms 0.10% 55.480ms 18.493ms 3
aten::argmax 0.09% 53.149ms 0.09% 53.153ms 13.288ms 4
aten::index 0.01% 5.771ms 0.01% 5.839ms 324.389us 18
aten::empty 0.00% 1.088ms 0.00% 1.088ms 77.714us 14
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::spmm_sum 87.35% 11.826s 87.36% 11.827s 1.314s 9
aten::linear 0.00% 92.000us 5.87% 794.451ms 88.272ms 9
aten::matmul 0.00% 62.000us 5.87% 794.208ms 88.245ms 9
aten::mm 5.87% 794.143ms 5.87% 794.146ms 88.238ms 9
aten::relu 0.00% 53.000us 3.35% 452.977ms 75.496ms 6
aten::clamp_min 3.35% 452.924ms 3.35% 452.924ms 75.487ms 6
aten::add_ 2.58% 348.663ms 2.58% 348.663ms 38.740ms 9
aten::argmax 0.42% 57.473ms 0.42% 57.475ms 14.369ms 4
aten::log_softmax 0.00% 22.000us 0.39% 52.605ms 17.535ms 3
aten::_log_softmax 0.39% 52.583ms 0.39% 52.583ms 17.528ms 3
aten::index 0.04% 5.100ms 0.04% 5.174ms 287.444us 18
aten::empty 0.01% 1.097ms 0.01% 1.097ms 78.357us 14
To break down the optimization scheme a little bit:
The balanced thread partition targets balancing the per-thread payload. Basically, if we directly parallelize along the row direction, the per-thread workload looks like this (I collected the number of edges handled by each thread):
### thread: 0; min: 1; max: 17482; avg = 172.599
### thread: 1; min: 1; max: 9918; avg = 137.251
### thread: 2; min: 1; max: 5786; avg = 39.7606
### thread: 3; min: 1; max: 4062; avg = 40.0852
### thread: 4; min: 1; max: 10406; avg = 39.7207
### thread: 5; min: 1; max: 3491; avg = 40.0985
### thread: 6; min: 1; max: 5965; avg = 40.0117
### thread: 7; min: 1; max: 5865; avg = 40.3841
### thread: 8; min: 1; max: 5892; avg = 39.969
### thread: 9; min: 1; max: 6076; avg = 39.9995
### thread: 10; min: 1; max: 5215; avg = 40.0757
### thread: 11; min: 1; max: 3893; avg = 40.1075
### thread: 12; min: 1; max: 8052; avg = 39.8108
### thread: 13; min: 1; max: 4062; avg = 39.7186
### thread: 14; min: 1; max: 3243; avg = 40.3022
### thread: 15; min: 1; max: 5008; avg = 40.4213
### thread: 16; min: 1; max: 7657; avg = 40.0987
### thread: 17; min: 1; max: 6784; avg = 40.0618
### thread: 18; min: 1; max: 4810; avg = 39.8836
### thread: 19; min: 1; max: 6429; avg = 39.9829
We can see that the first 2 threads have much more payload than the others, so we need to balance the thread payload here. Normally we could use dynamic scheduling in OpenMP, but that won't fit into PyTorch's at::parallel_for, which is essentially static scheduling, so I did manual partitioning here (the logic may be further refined; will do later).
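A hedged Python sketch of that manual partitioning idea (the actual code lives in the C++ kernel; balanced_row_partition is an illustrative name):

```python
# Hedged sketch of balanced thread partitioning: instead of giving every thread the same
# number of rows, split rows so each thread owns roughly the same number of edges,
# using the CSR row pointer as a cumulative edge count.
import bisect

def balanced_row_partition(rowptr, num_threads):
    total = rowptr[-1]
    bounds = [0]
    for t in range(1, num_threads):
        target = total * t // num_threads              # desired cumulative edge count
        bounds.append(bisect.bisect_left(rowptr, target))
    bounds.append(len(rowptr) - 1)
    return list(zip(bounds[:-1], bounds[1:]))          # thread t handles rows [lo, hi)

# Example: 6 rows with skewed degrees, 3 threads.
rowptr = [0, 10, 11, 12, 13, 23, 24]
print(balanced_row_partition(rowptr, 3))               # [(0, 1), (1, 5), (5, 6)]
```

With a skewed degree distribution the split follows cumulative edge counts rather than row counts, which is the point of the manual partitioning.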
Optimize torch.gather for the classic PyG use case (the index tensor is broadcasted); this will be the backward of scatter in training: https://github.com/pytorch/pytorch/pull/87586

When the index tensor is broadcasted along the last dimension, we can parallelize on the outer dimension and vectorize on the inner dimension, which is similar to torch.index_select. Compared to scatter, this one is much easier to write.
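A small illustration of that case (shapes scaled down; the benchmark shapes are on the order of [135K, 256] and [477K, 256]):

```python
# Hedged sketch: when `index` is broadcasted along the last dimension, gather(0, index)
# picks whole rows and is equivalent to index_select, which is why the kernel can be
# parallelized over rows and vectorized over the feature dimension.
import torch

src = torch.randn(1000, 16)
row = torch.randint(0, src.size(0), (4000,))
index = row.view(-1, 1).expand(-1, src.size(1))   # broadcasted index, strides (1, 0)

out_gather = src.gather(0, index)
out_select = src.index_select(0, row)
assert torch.equal(out_gather, out_select)
```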
Sort out the 2nd stage optimization a little bit:
- sampled_addmm on SparseCSR: https://github.com/pytorch/pytorch/pull/90978 (see the sketch below)
- sampled_addmm on SparseCOO (canceled)
- ReduceTypes: GNN relies on a few operators that share similar ReduceTypes, such as ScatterReduce, SegmentReduce, SampledReduce, SpmmReduce
- segment_reduce with lengths and offsets
- sampled_reduce
- sampled_addmm COO: this operator writes to a COO, so we need to make sure it is coalesced to make it parallel, and coalesced() is very slow right now. I measured this on ogbn-products.
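For the CSR variant, a hedged usage sketch of torch.sparse.sampled_addmm (whether the CPU path is available depends on the PyTorch version; the tiny pattern here is illustrative):

```python
# Hedged sketch: sampled_addmm computes beta * input + alpha * (mat1 @ mat2), but only
# at the positions present in the sparse CSR `input` (the "sampled" part), returning CSR.
import torch

crow = torch.tensor([0, 2, 3])
col = torch.tensor([0, 1, 1])
val = torch.zeros(3)
mask = torch.sparse_csr_tensor(crow, col, val, size=(2, 2))   # sparsity pattern to sample
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 2)

out = torch.sparse.sampled_addmm(mask, mat1, mat2)            # sparse CSR output
print(out.to_dense())
```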
🚀 The feature, motivation and pitch
The goal of this roadmap is to optimize CPU performance for PyG (including torch_scatter, torch_sparse).

For the first step, we will start with single-node inference performance optimization. The next step will extend the optimization effort to (distributed) training.
Performance Profiling
CPU platform: Icelake Xeon
Generic benchmarking
- (torch_sparse::spmm_sum 96.04%)
- (DataLoader 83.49%, aten::scatter_add_ 8.47%)
- (DataLoader 59.83%, aten::scatter_add_ 24.76%)
- (aten::scatter_add_ 27.61%, DataLoader 25.70%, aten::index 20.26%)
- (aten::scatter_add_ 30.91%, torch_scatter::scatter_max 24.54%, aten::mm 10.34%, aten::index_select 6.71%); most of the models under pytorch_geometric/benchmark/citation have similar behavior from a performance perspective.
- (aten::addmm 21.69%, aten::scatter_add_ 20.60%, aten::index_select 13.48%, DataLoader 12.31%)
- (torch_scatter::scatter_max 39.34%, torch_scatter::scatter_min 39.25%); need follow-up: need to get the scatter_reduce tensor shape/stride (similar issue as aten::scatter_add_?)
- (torch_scatter::scatter_max 66.91%, torch_cluster::knn 23.56%), source: benchmark/points/edge_cnn.py
- (torch_scatter::scatter_max 53.61%, aten::index_select 21.73%, DataLoader 16.11%), source: https://github.com/pyg-team/pytorch_geometric/pull/4915

Large dataset benchmarking
DataLoader (with preprocessing of the input data) is the major bottleneck here, mostly from_numpy (246s) and to (169s), triggered by data type conversion; source: convert_batch.

Performance Hotspots
- DataLoader (NeighborSampler).

Python level API upgrade in model scripts
The DataLoader is a major hotspot, so the first step is to upgrade the DataLoader from NeighborSampler to NeighborLoader, which has a native C++ implementation:
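A hedged sketch of the model-script change (argument values and the toy graph are illustrative, not taken from the benchmark scripts):

```python
# Hedged sketch: replace the Python-level NeighborSampler with NeighborLoader,
# which performs the neighbor sampling in native C++.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
# from torch_geometric.loader import NeighborSampler   # old-style loader being replaced

data = Data(x=torch.randn(1000, 64),
            edge_index=torch.randint(0, 1000, (2, 5000)))

# Before:
# loader = NeighborSampler(data.edge_index, sizes=[25, 10], batch_size=1024)

# After:
loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=1024)
```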
Native level kernel optimization
Phase One Optimizations
- NeighborLoader parallelization: the current impl is sequential (probably to avoid oversubscription with multiple workers on the data loader). Unlike GPU runs, asynchronously running the data loading thread and the computation thread does not always make sense; on some occasions, running the data loading step and the computation step sequentially while making each torch operator parallel via OpenMP (i.e. intra-op parallelism) makes more sense. The hotspot is torch_sparse::neighbor_sample.
- aten::sort: GCN + ogbn-products spent roughly 1/3 of the time on sort in the preprocessing step (which is not covered by the profiler result for the model inference), introduced by indexing from the sparse tensor at gnn.py#L123. The root cause is that aten::sort(dim) could only be parallelized on dimensions != dim, and the grain size was not correctly set. Fixed by #74897.
- spmm_{max|mean|sum} (torch_sparse): add vectorization and prefetching (indirect memory access) and apply blocking on M and K (if necessary).
- scatter_add and scatter_max (torch_scatter): optimized scatter_add (with extended index) with #82703. Still needs more polishing work. The current impl for scatter_add tries to parallelize on the inner dimension to avoid write conflicts; ideally we should parallelize on the outer dimension and vectorize on the inner dimension, yet we need to resolve the write conflict on the output tensor. Experiment with different impls for the given input range.
- index_select: optimized via #76868.
- index: directly optimizing index would be difficult; maybe we can change it to a more performant op like index_select from NeighborLoader, or customize its kernel from NeighborLoader.

Phase Two Optimizations
- scatter_reduce
- segment_reduce, and align the reduce types between scatter_reduce, spmm_reduce, segment_reduce.
- TensorTypeId of CPU and SparseCPU.
- scatter_add: cache the sorted index.
- knn (torch_cluster): need follow-up shape info to determine the proper method to parallelize the kernel. Evaluate knn from oneAPI DAL.

Design option for vectorization
To vectorize kernels from torch-sparse and torch-scatter, we have multiple options:
1. Use #pragma omp simd and add a compiler flag such as march=skylake-avx512, but this won't apply to bfloat16 (bfloat16 is an overload of uint16 and won't be vectorized properly by the compiler).
2. Use at::vec::Vectorized<scalar_t>: this will apply to bfloat16, but we need to customize the cmake scripts to make them compatible with PyTorch's CPU build flags: _DEFAULT (scalar code), _AVX2 and _AVX512.
3. at::vec::Vectorized<scalar_t> will work without any change, but we need to move the operator from torch-sparse/torch-scatter to torch. Makes more sense for the fused kernel of GAS.

(The current decision is to go with option 3 as much as we can.)
Bfloat16 enabling in torch-sparse/torch-scatter (highly related to the vectorization method chosen)
Validation