tchaton opened this issue 5 years ago
Hi,
applying scatter_* to sorted indices is commonly known as segment_*, which, e.g., TensorFlow supports but PyTorch does not. This is a good option to increase the speed of message passing in dense graphs. There is an issue buried deep in the PyTorch tracker to add this functionality. For scatter_*, sorted indices can, however, result in increased runtimes due to the use of atomic operations, since more threads run the danger of writing to the same output.
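For illustration, a minimal sketch of the two styles (assuming a plain sum; this is not torch_scatter's actual implementation): scatter_add_ handles arbitrary index order via atomic adds on the GPU, while a segment-style sum over sorted indices can reduce contiguous runs without atomics, e.g., via a cumulative sum:

```python
import torch

# Illustrative comparison (not torch_scatter's implementation):
src = torch.tensor([1., 2., 3., 4., 5.])
index = torch.tensor([0, 0, 1, 1, 2])  # already sorted

# scatter-style: works for any index order, but uses atomic adds on GPU.
out = torch.zeros(3).scatter_add_(0, index, src)

# segment-style: exploits sortedness, reduces contiguous runs, no atomics.
counts = torch.bincount(index, minlength=3)
ends = counts.cumsum(0)
csum = torch.cat([torch.zeros(1), src.cumsum(0)])
seg = csum[ends] - csum[ends - counts]

assert torch.allclose(out, seg)  # both give tensor([3., 7., 5.])
```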
But it could be possible to perform a scan on the compact given by the targets for each source, and then copy the scan result to the source index. By doing so, couldn't the code be a bit faster?
Can you elaborate? I am not sure I understand.
@rusty1s,
About the compact: https://www.youtube.com/watch?v=GyYfg3ywONQ&list=PLAwxTw4SYaPnFKojVQrmyOGFCqHTxfdv2&index=170
And the scan: https://www.youtube.com/watch?v=_5sM-4ODXaA&list=PLAwxTw4SYaPnFKojVQrmyOGFCqHTxfdv2&index=141
I was thinking we could filter by source, get the compact of targets associated with each source, perform a scan of this set using the chosen operator, and copy the last value to the associated source.
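As a rough sketch of the semantics (a sequential reference only; on the GPU the scan would run over segments in parallel), assuming max as the operator and edges already sorted by source:

```python
import torch

# Sequential reference for the idea: edges are sorted by source, we run an
# inclusive scan per segment with the chosen operator (max here), and the
# last element of each segment is the aggregated value for that source.
values = torch.tensor([1., 4., 2., 8., 5.])   # one value per edge
segment = torch.tensor([0, 0, 0, 1, 1])       # source id per edge (sorted)

scan = values.clone()
for i in range(1, len(values)):
    if segment[i] == segment[i - 1]:
        scan[i] = torch.maximum(scan[i], scan[i - 1])

# Mask selecting the last edge of each segment.
last = torch.cat([segment[1:] != segment[:-1], torch.tensor([True])])
result = scan[last]  # tensor([4., 8.]) -> per-source max
```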
Best, Thomas Chaton.
I don't see how this is faster (provided that I understand you correctly): what about sum() or mean()? You do not parallelize over the node/edge dimension anymore.

I don't see why it wouldn't be parallelized with this.
A scan (we don't need the downsweep part) is extremely efficient on GPUs for the scatter operator and takes log(n) steps to find the mean / max / etc., whereas your aggregation is sequential, if I am right: you go to the next target and check whether it is greater than the scanned max (or accumulate for mean, add, etc.).
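To make the log(n) point concrete, here is a toy Hillis-Steele style inclusive max scan (illustrative only; it ignores segment boundaries, which a segmented scan would additionally track): each pass doubles the reach, so n elements need only ceil(log2(n)) passes instead of a length-n sequential loop.

```python
import torch

# Toy Hillis-Steele inclusive scan with max (no segment handling):
x = torch.tensor([1., 4., 2., 8., 5., 3., 7., 6.])
stride = 1
while stride < len(x):
    # Combine each element with the one `stride` positions to its left.
    x = torch.cat([x[:stride], torch.maximum(x[stride:], x[:-stride])])
    stride *= 2
# x is now the running max: [1, 4, 4, 8, 8, 8, 8, 8]
```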
Ok, I think I finally understand. You still parallelize over the complete edge dimension, but perform the scan only on the parts where neighboring indices match. This could very well be what segment_* is doing internally.
@rusty1s,
I will give it a try in my free time to see if it brings anything, and play with the HAG paper on hierarchical aggregation too. I will keep you updated on my findings.
Best, Thomas Chaton.
Hey @rusty1s,
I have contacted the main author of HAG. He should grant me access to their code in the coming week. He also worked on this amazing DL graph optimization project: https://github.com/jiazhihao/TASO
I am going to fork pytorch_scatter and work on it with some people. Feel free to help. If we get better performance (speed / memory), we will try to merge it back.
Best, Thomas Chaton.
This sounds awesome! Please keep me updated. If you have any questions, feel free to reach out :)
Hey @rusty1s,
would you like to have access to the HAG code too?
Best, Thomas Chaton
If this is possible, sure :)
Hey @rusty1s,
Here is the repo: https://github.com/jiazhihao/gnn. You should have an invitation to join. They are also going to work on integrating it within the PyG backend. That's great news :)
Best, Thomas Chaton.
Hey @rusty1s,
Could you please sketch an interface for how you would like to use HAG within torch_scatter? I have started to work on it. I was thinking there could be two ways to use it.
One way would compute the HAG from the edges (this will need both source and target instead of just "target" in the current API) at every forward pass.
Best, Thomas Chaton
Hi,
I believe the scatter HAG algorithm should be implemented on its own, e.g., in a .hag subpackage.
Then, there should be a method to precompute the HAG, which the scatter calls expect as an input.
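A hypothetical sketch of what such an interface might look like (build_hag, HAG, and this scatter_add signature are illustrative names only, not an existing torch_scatter API):

```python
import torch

class HAG:
    """Hypothetical: precomputed hierarchically aggregated graph for a
    fixed edge_index, reusable across forward passes."""
    def __init__(self, edge_index: torch.Tensor):
        self.edge_index = edge_index
        # ... detect shared neighbor sets and insert intermediate
        # aggregation nodes here ...

def build_hag(edge_index: torch.Tensor) -> HAG:
    """Hypothetical: precompute the HAG once per graph."""
    return HAG(edge_index)

def scatter_add(src: torch.Tensor, hag: HAG, dim: int = 0) -> torch.Tensor:
    """Hypothetical: aggregate shared neighbor sets once through the
    HAG's intermediate nodes, then scatter the reduced edge set."""
    raise NotImplementedError
```

Usage would then be `hag = build_hag(edge_index)` once per graph, followed by `scatter_add(src, hag)` in every forward pass.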
Hey @rusty1s,
Yes, I was thinking about something like that. Here is the repo: https://github.com/tchaton/tsd
Best, Thomas Chaton.
❓ Questions & Help
Hey @rusty1s, I have started to read the pytorch_scatter C++ and GPU code. I might have some questions, as there are not many comments.
Since MessagePassing does both index_select and scatter, why don't you sort the edge_index to reduce jumps in memory?
Example:
```python
if self.flow == "target_to_source":
    E = edge_index
    idx = np.lexsort((E[:, 0], E[:, 1]))  # sort by source and then by target
    E = E[idx]  # sorted edge_index -> source is contiguous and could be
                # scattered using just an offset
```
What do you think?
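For reference, a minimal sketch of the index_select + scatter pattern in question (simplified, not PyG's actual MessagePassing implementation):

```python
import torch

# Simplified gather/scatter pattern (illustrative, not PyG's actual code):
x = torch.randn(4, 16)                    # node features
edge_index = torch.tensor([[0, 1, 1, 2],  # row 0: sources
                           [1, 0, 2, 3]]) # row 1: targets

msg = x.index_select(0, edge_index[0])    # gather: one message per edge
out = torch.zeros_like(x).index_add_(0, edge_index[1], msg)  # scatter-add
```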
Best, T.C