pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

[Roadmap] CPU Performance Optimization for PyG #4891

Open mingfeima opened 2 years ago

mingfeima commented 2 years ago

🚀 The feature, motivation and pitch

The goal of this roadmap is to optimize CPU performance for PyG (including torch_scatter, torch_sparse).

As a first step, we will start with single-node inference performance optimization on:

The next step will be to extend the optimization effort to (distributed) training.

Performance Profiling

CPU platform: Icelake Xeon

Generic benchmarking

Large dataset benchmarking

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        71.27%      608.842s        71.39%      609.891s      70.223ms          8685  
                                torch_sparse::spmm_mean        14.91%      127.390s        14.93%      127.522s       7.342ms         17370  
                                            aten::addmm         3.77%       32.166s         7.34%       62.727s       1.806ms         34740  
                                            aten::copy_         3.60%       30.766s         3.60%       30.766s     161.007us        191082  
                                               aten::mm         2.29%       19.588s         2.30%       19.683s       1.133ms         17370  
                                aten::native_batch_norm         0.94%        7.989s         1.01%        8.657s     332.256us         26055  

DataLoader (with preprocessing of the input data) is the major bottleneck here, mostly from_numpy (246s) and to (169s) triggered by data type conversion, originating from convert_batch.

Performance Hotspots

Python level API upgrade in model scripts

The DataLoader is a major hotspot, so the first step is to upgrade the DataLoader from NeighborSampler to NeighborLoader, which has a native C++ implementation:
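For reference, a minimal sketch of the target usage (the dataset choice, fan-out, and batch size below are only illustrative):

```python
import torch
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader

# Reddit is one of the profiled datasets; any Data object with a train_mask
# works the same way (note: the download is large).
data = Reddit(root='./data/Reddit')[0]

# NeighborLoader replaces the Python-level NeighborSampler: neighbor sampling
# is driven by the native C++ routines and yields ready-to-use mini-batches.
loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],        # fan-out per GNN layer
    batch_size=1024,
    input_nodes=data.train_mask,   # seed nodes to sample around
    num_workers=0,                 # keep sampling on the main thread for CPU runs
)

for batch in loader:
    # batch.x / batch.edge_index hold the sampled subgraph; the first
    # batch.batch_size nodes are the seed nodes.
    pass
```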

Native level kernel optimization

Phase One Optimizations

The current implementation of scatter_add tries to parallelize on the inner dimension to avoid write conflicts; ideally we should parallelize on the outer dimension and vectorize on the inner dimension, but that requires resolving the write conflict on the output tensor. We will experiment with different implementations for the given input ranges.

Phase Two Optimizations

Design option for vectorization

To vectorize kernels from torch-sparse and torch-scatter, we have multiple options:

(the current decision is to go with option 3 as much as possible)

Bfloat16 enabling in torch-sparse/torch-scatter

(highly related to the vectorization method chosen)
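A rough sketch of how bf16 could be exercised end to end at the Python level (assuming autocast-style bf16 on CPU; whether each op actually runs in bf16 depends on the autocast policy and on kernel support in torch-sparse/torch-scatter):

```python
import torch
from torch_geometric.nn import SAGEConv

# Toy graph just to make the snippet self-contained.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
conv = SAGEConv(8, 16)

# CPU inference under bfloat16 autocast; ops without bf16 kernels fall back
# to fp32.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = conv(x, edge_index)
print(out.dtype)
```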

Validation

rusty1s commented 2 years ago

Thanks for this detailed issue and roadmap. PyTorch recently released torch.scatter_reduce as well. As such, the long-term goal is to move to the PyTorch implementation of torch.scatter_reduce routines, and current optimizations of torch-scatter are probably not future-proof as a result. Can we also benchmark torch.scatter_reduce and torch_scatter to see if there is already a performance gain by simply switching the implementation?
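Something along these lines would already tell us a lot (a rough micro-benchmark sketch; the shapes are borrowed from the Reddit profile further down and are only illustrative):

```python
import time
import torch
from torch_scatter import scatter

num_nodes, num_edges, dim = 135361, 477263, 256
src = torch.randn(num_edges, dim)
index = torch.randint(0, num_nodes, (num_edges,))

def bench(fn, iters=10):
    # warm-up, then average wall time
    for _ in range(3):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

# torch_scatter: reduce over dim 0 with an explicit output size
t_ts = bench(lambda: scatter(src, index, dim=0, dim_size=num_nodes, reduce="sum"))

# torch.scatter_reduce: needs the index broadcast to src's shape
index2d = index.view(-1, 1).expand(-1, dim)
out = torch.zeros(num_nodes, dim)
t_pt = bench(lambda: out.scatter_reduce(0, index2d, src, reduce="sum",
                                        include_self=False))

print(f"torch_scatter: {t_ts*1e3:.1f} ms, torch.scatter_reduce: {t_pt*1e3:.1f} ms")
```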

mingfeima commented 2 years ago

Ok, I see. Then it is better to optimize scatter_reduce in torch. I just checked the code: scatter_add and scatter_reduce share the same kernel in torch, so they have the same performance issues. Will have it fixed.

Padarn commented 2 years ago

hey @mingfeima this issue is great, a lot of detail.

NeighborLoader parallelization: the current impl is sequential (probably to avoid oversubscription with multiple workers on the data loader). Unlike GPU runs, running the data loading thread and the computation thread asynchronously does not always make sense. On some occasions, running the data loading step and the computation step sequentially, while making each torch operator parallel with OpenMP (intra-op parallelism), makes more sense.

I'm a complete newbie to this, so my question is to learn, not to suggest something. Can you explain what you're intending to change here? It sounds like you want to keep it sequential but make the actual sampling itself parallel?

mingfeima commented 2 years ago

I'm a complete newbie to this, so my question is to learn, not to suggest something. Can you explain what you're intending to change here? It sounds like you want to keep it sequential but make the actual sampling itself parallel?

Yes, that's the idea! The data loader from pytorch is more suitable for GPU (by setting num_workers=N, it will launch data loading threads asynchronously with the main computation thread). On CPU, it is probably better to run data loading and computation sequentially while parallelizing the sampler inside the data loader with OpenMP threads.

Padarn commented 2 years ago

That makes complete sense. Let me know if you'd like any help (I'd likely not be quick though 😅)

rusty1s commented 2 years ago

@mingfeima I assume the benchmarks have been run with num_workers=0? This explains why this is a bottleneck. Can you share some insights on how an OpenMP implementation of sampling behaves in relation to num_workers>0? Is it expected that this will potentially slow down the code compared to a single threaded implementation that utilizes parallelism solely on the worker level?

mingfeima commented 2 years ago

The current benchmark profiling results use the default setting. Some scripts, for example to_hetero_mag, explicitly set num_workers; if not, the default setting will be 4.

DataLoader time in the benchmark profile result actually consists of two parts:

The second part takes more time, so it can still be improved with a single worker + parallel OpenMP. If we use num_workers>0, we need to make sure OpenMP in each worker has the correct settings (omp_num_threads and core affinity binding) to avoid oversubscription.

Actually, data loader optimization is a rather complex issue, perhaps more complex than optimizing the kernels :( since it is more of a tuning job to achieve the most balanced situation between workload payload (memory footprint, computation complexity, etc.) and hardware capacity (IO, memory bandwidth, ALU flops).

Usually we do not do data loading optimizations, since the real case in deployment would probably be even more complex (some vendors have mechanisms like prefetching and batching to improve overall user experience and efficiency). But the thing is, DGL has done some optimizations here, so we need to do at least something similar, otherwise out-of-the-box performance on PyG would look bad.

Anyway, we will make sure that OpenMP has correct settings for either num_workers=0 or num_workers=N, and also that each of the samplers can be properly parallelized. num_workers=0 benefits preprocessing more, and num_workers=N benefits IO more. We will let users decide which way to go (maybe we can give a BKM or some simple guideline).
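As a rough illustration of the num_workers=N case (a minimal sketch; the thread count, fan-out, and synthetic graph are placeholders, and core affinity would still be set via OMP/KMP environment variables or numactl):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Small synthetic graph just to make the example self-contained.
num_nodes = 1000
edge_index = torch.randint(0, num_nodes, (2, 5000))
data = Data(x=torch.randn(num_nodes, 32), edge_index=edge_index)

def worker_init_fn(worker_id: int):
    # Give each sampling worker a single intra-op thread so that
    # N workers x M OpenMP threads does not oversubscribe the cores.
    torch.set_num_threads(1)

# num_workers=0 keeps sampling on the main thread (all OpenMP threads go to
# intra-op parallelism); num_workers>0 trades that for parallel IO, so the
# per-worker thread count has to be capped.
loader = NeighborLoader(
    data,
    num_neighbors=[10, 5],
    batch_size=128,
    num_workers=4,
    worker_init_fn=worker_init_fn,
)
```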

mingfeima commented 2 years ago

Updates on scatter_add optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/82703

Initiative

Depending on the type of edge_index, message passing will choose different paths: a) scatter_add for a dense tensor; b) spmm for a SparseTensor. The principal factor here is the memory format: in the first case the memory format for edge_index is COO, while in the second case it is CSR.
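A minimal sketch of the two paths (module choice and sizes are only illustrative):

```python
import torch
from torch_sparse import SparseTensor
from torch_geometric.nn import GCNConv

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])  # COO: [2, num_edges]

conv = GCNConv(16, 32)

# a) dense edge_index (COO layout) -> aggregation goes through scatter_add
out_coo = conv(x, edge_index)

# b) SparseTensor (CSR layout) -> aggregation goes through spmm
adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                     sparse_sizes=(6, 6)).t()
out_csr = conv(x, adj_t)
```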

Problem description

scatter_add is used to aggregate info row-wise, which means the index tensor is broadcasted along the last dimension (each row holds a single repeated value).

A typical input shape for the dataset of reddit looks like:

self.sizes(): [135361, 256]; self.strides(): [256, 1]
index.sizes(): [477263, 256]; index.strides(): [1, 0]
src.sizes(): [477263, 256]; src.strides(): [256, 1]

So we pick rows from 477k indices and update the destination rows in self. Ideally we want to parallelize on an outer dimension like 477k or 135k, but the scatter pattern indicates that writes conflict among threads. The current ATen kernel chooses to parallelize on the inner dimension of 256, which is not performant for the PyG usage: a) per-thread memory access is non-contiguous; b) it cannot be vectorized.
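The pattern can be reproduced with a few lines (sizes taken from the Reddit shapes above):

```python
import torch

num_dst, num_src, dim = 135361, 477263, 256
self_t = torch.zeros(num_dst, dim)                  # strides (256, 1)
src = torch.randn(num_src, dim)                     # strides (256, 1)

# The index is a 1-D destination-row id broadcast along the feature dim,
# which is why its strides are (1, 0).
row = torch.randint(0, num_dst, (num_src,))
index = row.view(-1, 1).expand(-1, dim)
print(index.shape, index.stride())                  # [477263, 256], (1, 0)

# Any two threads owning different slices of the 477k source rows may write
# to the same destination row, hence the write conflict on `self_t`.
self_t.scatter_add_(0, index, src)
```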

Algorithm

There exist a couple of algorithms to fix the write conflict, such as: a) sorting; b) segment mutex; c) atomics; d) shared buffer, ... I chose a) sorting based on the input shape range, which should be the most performant. If anyone comes up with a better idea, please let me know :)
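For illustration only, a slow single-threaded reference of the sorting approach (the actual kernel is parallel C++ inside ATen; the helper name here is made up):

```python
import torch

def scatter_add_via_sort(src, row, num_rows):
    # 1) sort the destination indices so that writes to the same output
    #    row become contiguous
    perm = torch.argsort(row)
    sorted_row, sorted_src = row[perm], src[perm]

    # 2) build row_ptr (CSR-style offsets) from the sorted indices
    counts = torch.bincount(sorted_row, minlength=num_rows)
    row_ptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)]).tolist()

    # 3) segment reduction: each output row is now an independent slice,
    #    so a parallel kernel can assign rows to threads without conflicts
    out = src.new_zeros(num_rows, src.size(1))
    for r in range(num_rows):
        out[r] = sorted_src[row_ptr[r]:row_ptr[r + 1]].sum(dim=0)
    return out

src = torch.randn(100, 8)
row = torch.randint(0, 10, (100,))
ref = torch.zeros(10, 8).scatter_add_(0, row.view(-1, 1).expand(-1, 8), src)
assert torch.allclose(scatter_add_via_sort(src, row, 10), ref, atol=1e-5)
```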

So,

Result

I used SAGE+reddit from https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py (the reason I picked this one is that it uses NeighborLoader, which means I don't have to take data loader optimization into account here).

For inference on a single-socket ICX Xeon with 20 cores @ 2.50GHz, end-to-end inference time was reduced from 77.135s to 44.622s. Attaching part of the profiling logs; as we can see, scatter_add was reduced from 37.797s to 6.454s.

  1. before
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     aten::scatter_add_        49.00%       37.797s        49.00%       37.797s      41.445ms           912
                                     aten::index_select        19.74%       15.223s        19.74%       15.227s       6.678ms          2280
                                           aten::linear         0.01%       5.706ms        15.04%       11.602s      12.721ms           912
                                            aten::addmm         6.62%        5.108s         7.92%        6.112s      13.403ms           456
                                           aten::matmul         0.00%       2.339ms         7.10%        5.475s      12.006ms           456
                                               aten::mm         7.09%        5.472s         7.09%        5.472s      12.001ms           456
                                            aten::index         5.89%        4.544s         5.90%        4.549s       9.845ms           462
                                            aten::fill_         3.59%        2.768s         3.59%        2.768s       2.014ms          1374
                                            aten::zeros         0.01%       7.616ms         3.44%        2.652s       1.936ms          1370
                                            aten::zero_         0.00%       2.728ms         3.42%        2.636s       1.924ms          1370
  2. after
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     aten::index_select        32.26%       14.395s        32.27%       14.398s       6.315ms          2280
                                           aten::linear         0.01%       6.329ms        26.19%       11.688s      12.815ms           912
                                     aten::scatter_add_        14.46%        6.454s        14.46%        6.454s       7.077ms           912
                                            aten::addmm        11.71%        5.223s        13.58%        6.060s      13.289ms           456
                                           aten::matmul         0.01%       2.257ms        12.58%        5.612s      12.307ms           456
                                               aten::mm        12.57%        5.610s        12.57%        5.610s      12.302ms           456
                                            aten::index         9.98%        4.453s         9.99%        4.456s       9.646ms           462
                                            aten::fill_         5.62%        2.506s         5.62%        2.506s       1.369ms          1830
                                            aten::zeros         0.02%       7.091ms         5.53%        2.466s       1.800ms          1370
                                            aten::zero_         0.01%       2.886ms         5.50%        2.453s       1.790ms          1370
                                     aten::true_divide_         0.02%       7.360ms         4.72%        2.106s       4.618ms           456

There are still some TODOs to follow up on, which will bring some additional performance improvement:

rusty1s commented 2 years ago

Just to understand: Does this mean that we first sort index and then do a segment reduction? In that case it might be good to preserve the information that index is sorted such that we do not have this overhead in consecutive GNN layers.

mingfeima commented 2 years ago

Just to understand: Does this mean that we first sort index and then do a segment reduction? In that case it might be good to preserve the information that index is sorted such that we do not have this overhead in consecutive GNN layers.

Yes, that's the idea. The overhead is not only the sorting; we also have to calculate the row_ptr indices... So the index should be constant, since it is an attribute of the dataset? If we can cache the sorted index, scatter_add performance could be further improved by roughly 1/3.

rusty1s commented 2 years ago

As far as I understand, this would refer to a segment_add implementation, correct? Similar to the one present in torch-scatter.segment. Is there also a chance we can optimize scatter_add without relying on sorting?

mingfeima commented 2 years ago

Yes, the current scatter_add is kind of like sorting + segment_add, and both parts are properly parallelized.

Because of the semantics limitation, we cannot skip sorting, since on the PyTorch side there is no guarantee that index is in ascending order.

As for the optimization techniques for scatter_add, I tried a couple of methods: a) sorting (the currently submitted PR uses this approach); b) a mutex on a block of the write addresses; c) atomics on the innermost dimension; and so on. My experiments show that a) performs best right now... since sorting also helps to increase cache locality, and we can enable blocking on the nnz dimension (so there is only one write for each row and multiple reads for src). For b) we cannot do blocking on writes, so there would be multiple writes as well; c) is only suitable for some inner sizes, e.g. src: [135K, 1] and index: [477K, 1], and atomics come at a higher price than normal FMA.

The overhead is actually not mainly from the sorting itself but from memory allocation; I will switch to the c10 CPU allocator to see if it helps.

Aside from that, we probably have a chance to cache the sorted index so as to save sorting for consecutive layers. Maybe add an attribute called edge_index_sorted: for the 1st layer we fill it with the sorted index, and in the consecutive layers we can directly use segment_add. (Of course, we also need to make sure that segment_add is fully optimized.)

This is only a rough idea at the moment. My point is that we first clear the performance bottlenecks in torch (so we optimize scatter_add as it is, with no upper-level API change) and then seek more optimization opportunities on the pyg/torch-scatter/torch-sparse side, where we can make more aggressive optimizations.

rusty1s commented 2 years ago

Got it, thanks for clarifying!

mingfeima commented 2 years ago

Update on spmm optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/83727.

This ports the spmm reduction from torch-sparse to torch. The current PR is only for demonstrating performance gains; the API definition needs more amendment.

For now only sum is added; more will come in the future (max, mean, min), and the algorithm is pretty much the same.

I selected the benchmark from ./ogb/examples/nodeproppred/products/gnn.py, since originally it spent the majority of its time in torch_sparse::spmm_sum. The spmm got roughly a 5x speedup on my 20-core machine.
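For reference, with sum reduction the ported path computes the same thing as a plain CSR sparse-dense matmul; a small sketch (assuming the CSR matmul path in recent PyTorch; sizes are illustrative):

```python
import torch
from torch_sparse import SparseTensor, matmul

num_nodes, dim = 6, 4
edge_index = torch.tensor([[0, 0, 1, 2, 3, 4, 5],
                           [1, 2, 2, 3, 4, 5, 0]])
x = torch.randn(num_nodes, dim)

# torch-sparse path: this is what torch_sparse::spmm_sum executes internally.
adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                   sparse_sizes=(num_nodes, num_nodes))
out_ts = matmul(adj, x, reduce="sum")

# Plain PyTorch CSR path that the ported kernel targets.
rowptr, col, value = adj.csr()
value = torch.ones(col.numel()) if value is None else value
adj_csr = torch.sparse_csr_tensor(rowptr, col, value,
                                  size=(num_nodes, num_nodes))
out_pt = adj_csr @ x

assert torch.allclose(out_ts, out_pt, atol=1e-6)
```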

  1. before
    -----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       torch_sparse::spmm_sum        97.09%       56.086s        97.09%       56.088s        6.232s             9
                 aten::linear         0.00%      85.000us         1.38%     795.485ms      88.387ms             9
                 aten::matmul         0.00%      57.000us         1.38%     795.260ms      88.362ms             9
                     aten::mm         1.38%     795.201ms         1.38%     795.203ms      88.356ms             9
                   aten::relu         0.00%      50.000us         0.76%     440.434ms      73.406ms             6
              aten::clamp_min         0.76%     440.384ms         0.76%     440.384ms      73.397ms             6
                   aten::add_         0.57%     327.801ms         0.57%     327.801ms      36.422ms             9
            aten::log_softmax         0.00%      23.000us         0.10%      55.503ms      18.501ms             3
           aten::_log_softmax         0.10%      55.480ms         0.10%      55.480ms      18.493ms             3
                 aten::argmax         0.09%      53.149ms         0.09%      53.153ms      13.288ms             4
                  aten::index         0.01%       5.771ms         0.01%       5.839ms     324.389us            18
                  aten::empty         0.00%       1.088ms         0.00%       1.088ms      77.714us            14
  2. after
    -----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
               aten::spmm_sum        87.35%       11.826s        87.36%       11.827s        1.314s             9
                 aten::linear         0.00%      92.000us         5.87%     794.451ms      88.272ms             9
                 aten::matmul         0.00%      62.000us         5.87%     794.208ms      88.245ms             9
                     aten::mm         5.87%     794.143ms         5.87%     794.146ms      88.238ms             9
                   aten::relu         0.00%      53.000us         3.35%     452.977ms      75.496ms             6
              aten::clamp_min         3.35%     452.924ms         3.35%     452.924ms      75.487ms             6
                   aten::add_         2.58%     348.663ms         2.58%     348.663ms      38.740ms             9
                 aten::argmax         0.42%      57.473ms         0.42%      57.475ms      14.369ms             4
            aten::log_softmax         0.00%      22.000us         0.39%      52.605ms      17.535ms             3
           aten::_log_softmax         0.39%      52.583ms         0.39%      52.583ms      17.528ms             3
                  aten::index         0.04%       5.100ms         0.04%       5.174ms     287.444us            18
                  aten::empty         0.01%       1.097ms         0.01%       1.097ms      78.357us            14

To break down the optimization scheme a little bit:

The balanced thread partition targets balancing the thread payload. Basically, if we directly parallelize along the row dimension, the per-thread payload looks like this (I collected the number of edges for each thread):

### thread: 0; min: 1; max: 17482; avg = 172.599
### thread: 1; min: 1; max: 9918; avg = 137.251
### thread: 2; min: 1; max: 5786; avg = 39.7606
### thread: 3; min: 1; max: 4062; avg = 40.0852
### thread: 4; min: 1; max: 10406; avg = 39.7207
### thread: 5; min: 1; max: 3491; avg = 40.0985
### thread: 6; min: 1; max: 5965; avg = 40.0117
### thread: 7; min: 1; max: 5865; avg = 40.3841
### thread: 8; min: 1; max: 5892; avg = 39.969
### thread: 9; min: 1; max: 6076; avg = 39.9995
### thread: 10; min: 1; max: 5215; avg = 40.0757
### thread: 11; min: 1; max: 3893; avg = 40.1075
### thread: 12; min: 1; max: 8052; avg = 39.8108
### thread: 13; min: 1; max: 4062; avg = 39.7186
### thread: 14; min: 1; max: 3243; avg = 40.3022
### thread: 15; min: 1; max: 5008; avg = 40.4213
### thread: 16; min: 1; max: 7657; avg = 40.0987
### thread: 17; min: 1; max: 6784; avg = 40.0618
### thread: 18; min: 1; max: 4810; avg = 39.8836
### thread: 19; min: 1; max: 6429; avg = 39.9829

We can see that the first 2 threads have a much bigger payload than the others, so we need to balance the thread payload here. Normally we could use dynamic scheduling for OpenMP, but this won't fit into PyTorch's at::parallel_for, which is essentially static scheduling, so I did manual partitioning here (the logic may be further refined later).
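For illustration, the manual partitioning boils down to splitting rows by cumulative edge count rather than by row count (a sketch in Python; the helper name and the toy degrees are made up):

```python
import bisect
import torch

def balanced_row_partition(rowptr: torch.Tensor, num_threads: int):
    """Split CSR rows into num_threads chunks with roughly equal edge counts."""
    offsets = rowptr.tolist()          # rowptr[i+1] - rowptr[i] = edges in row i
    num_rows = len(offsets) - 1
    total_edges = offsets[-1]
    chunks, start = [], 0
    for t in range(num_threads):
        target = (t + 1) * total_edges // num_threads
        # first row boundary whose cumulative edge count reaches the target
        end = min(max(bisect.bisect_left(offsets, target), start), num_rows)
        chunks.append((start, end))    # thread t handles rows [start, end)
        start = end
    chunks[-1] = (chunks[-1][0], num_rows)   # make sure the tail is covered
    return chunks

deg = torch.tensor([40, 30, 8, 8, 8, 8, 8, 8, 8, 4])     # skewed degrees
rowptr = torch.cat([deg.new_zeros(1), deg.cumsum(0)])
print(balanced_row_partition(rowptr, num_threads=4))
# -> [(0, 1), (1, 2), (2, 6), (6, 10)]: ~40/30/32/28 edges per thread
```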

mingfeima commented 2 years ago

Update on Training optimization

Optimize torch.gather for the classic PyG use case (the index tensor is broadcasted); this is the backward of scatter in training: https://github.com/pytorch/pytorch/pull/87586

When the index tensor is broadcasted along the last dimension, we can parallelize on the outer dimension and vectorize on the inner dimension, similar to torch.index_select. Compared to scatter, this one is much easier to write.
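A small snippet showing the pattern (sizes mirror the Reddit shapes; gather with a broadcasted index degenerates into a row-wise index_select):

```python
import torch

src = torch.randn(135361, 256)
row = torch.randint(0, 135361, (477263,))

# Classic PyG pattern: the gather index is broadcast along the feature dim,
# so gather along dim 0 selects whole rows.
index = row.view(-1, 1).expand(-1, 256)
out_gather = src.gather(0, index)
out_select = src.index_select(0, row)
assert torch.equal(out_gather, out_select)
```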

GCN-Reddit (before): [benchmark chart: GCN-Reddit-1-16-before]

GCN-Reddit (after): [benchmark chart: GCN-Reddit-1-16-update-gather]

mingfeima commented 1 year ago

Sort out the 2nd stage optimization a little bit:

sampled_addmm COO: this operator writes to a COO tensor, so we need to make sure it is coalesced in order to parallelize it. And coalesce() is very slow right now. I measured on ogbn-products: