deepsphere / deepsphere-cosmo-tf2

A spherical convolutional neural network for cosmology (TFv2).
https://arxiv.org/abs/1810.12186
MIT License

Speedup by removing loops from Chebyshev polynomial computation #10

Open tomaszkacprzak opened 8 months ago

tomaszkacprzak commented 8 months ago

The polynomial values are computed in a loop:

https://github.com/deepsphere/deepsphere-cosmo-tf2/blob/3facedf3747abe3e065c45d3368d4d0b1ddf9a2d/deepsphere/gnn_layers.py#L127
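For reference, a loop of this kind typically implements the standard Chebyshev three-term recurrence. The following is a minimal NumPy/SciPy sketch of that recurrence, not the code at the linked line. It assumes `L` is a SciPy sparse Laplacian already rescaled so that its spectrum lies in [-1, 1], and `chebyshev_basis_loop` is a hypothetical name.

```python
import numpy as np
import scipy.sparse as sp


def chebyshev_basis_loop(L, x, degree):
    """Return T_0(L) x, ..., T_degree(L) x stacked along a new leading axis.

    L: (M, M) SciPy sparse Laplacian, assumed rescaled to the spectrum [-1, 1].
    x: (M, F) dense signal with F features.
    """
    outputs = [x]                      # T_0(L) x = x
    if degree >= 1:
        outputs.append(L @ x)          # T_1(L) x = L x
    for _ in range(2, degree + 1):
        # T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x: one sparse-dense matmul per step
        outputs.append(2 * (L @ outputs[-1]) - outputs[-2])
    return np.stack(outputs, axis=0)   # shape (degree + 1, M, F)


# Usage on a toy diagonal "Laplacian" (eigenvalues on the diagonal).
L = sp.diags([-0.9, -0.3, 0.0, 0.5, 0.8]).tocsr()
x = np.ones((5, 2))
print(chebyshev_basis_loop(L, x, 5).shape)  # (6, 5, 2)
```

For degree 5 this performs exactly five dependent sparse-dense matmuls (one for T_1, four inside the loop), which is the sequential chain the issue wants to remove.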

For the commonly used degree of 5, that requires 5 consecutive steps. It may be faster to execute this as a single sparse-dense matmul. That would require building a sparse Laplacian for each polynomial component, which can be precomputed. The sparse-dense matmul would then be launched only once, as sparse-dense-matmul( block-sparse L, tiled input ), followed by a reshape and a sum. For N = polynomial degree and M = map length, and ignoring the channel and batch dimensions, the matrix multiplication could look like this:

[ N*M x N*M ] is the block-sparse L
[ N*M x 1 ] is the input tiled N times
[ N*M x 1 ] = sparse-dense-matmul( [ N*M x N*M ], [ N*M x 1 ] )
[ N x M x 1 ] = reshape( [ N*M x 1 ], dims=[ N, M, 1 ] )

which gives N feature maps.
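The proposed single-matmul scheme can be sketched as follows. This is an illustration under assumptions, not a TF2 implementation: `scipy.sparse` stands in for `tf.sparse`, all function names are hypothetical, the blocks T_1(L), ..., T_N(L) are precomputed once with the recurrence applied to the matrices themselves, and the T_0(L) x = x term is left out because the proposed block matrix starts at order 1. The `weights` array is a hypothetical stand-in for a layer's learnable coefficients and implements the final "sum".

```python
import numpy as np
import scipy.sparse as sp


def precompute_block_laplacian(L, degree):
    """Build block_diag(T_1(L), ..., T_N(L)) once, as a single (N*M, N*M) sparse matrix.

    The blocks come from the Chebyshev recurrence applied to the matrices themselves.
    """
    L = sp.csr_matrix(L)
    T_prev, T_curr = sp.identity(L.shape[0], format="csr"), L
    blocks = [T_curr]
    for _ in range(2, degree + 1):
        T_prev, T_curr = T_curr, (2 * (L @ T_curr) - T_prev).tocsr()
        blocks.append(T_curr)
    return sp.block_diag(blocks, format="csr")


def chebyshev_maps_single_matmul(block_L, x, degree):
    """Return [T_1(L) x, ..., T_N(L) x] as an (N, M, F) array using one sparse-dense matmul."""
    M, F = x.shape
    tiled = np.tile(x, (degree, 1))        # (N*M, F): the input repeated N times
    stacked = block_L @ tiled              # (N*M, F): the single matmul
    return stacked.reshape(degree, M, F)   # (N, M, F): N feature maps


def chebyshev_conv(block_L, x, weights):
    """Weighted sum over orders and input features; weights: (N, F_in, F_out), hypothetical."""
    maps = chebyshev_maps_single_matmul(block_L, x, weights.shape[0])
    return np.einsum("nmf,nfo->mo", maps, weights)


# Usage on a toy diagonal "Laplacian".
L = sp.diags([-0.9, -0.3, 0.0, 0.5, 0.8])
block_L = precompute_block_laplacian(L, degree=5)
x = np.random.default_rng(0).standard_normal((5, 3))
print(block_L.shape)                                         # (25, 25)
print(chebyshev_conv(block_L, x, np.ones((5, 3, 4))).shape)  # (5, 4)
```

Because the blocks are materialized explicitly, `block_L.nnz` gives the actual memory footprint of the block matrix, which is the quantity worth benchmarking for realistic graphs.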

Any chance that would make sense?

Also tagging @Arne-Thomsen

jafluri commented 8 months ago

I agree that this would most likely boost the performance quite a bit. However, I am not sure what would happen to the memory usage. If I understand you correctly, the [ N*M x N*M ] block-sparse L contains blocks in which the original L was multiplied with itself up to N times. Since the original L is a symmetric sparse matrix with a dense diagonal and ~20 entries per row, I'd expect the sparsity of these blocks to go down quickly with each multiplication. So the blocks corresponding to the higher degrees of the polynomial might become almost dense.

I am not entirely sure, however, so I guess one could benchmark it.
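A quick way to benchmark the fill-in is to count the stored entries per row of L^k. The sketch below does this for a small periodic 2D grid graph with 4 neighbours per node, which is only an assumed stand-in for the spherical pixel graph; `grid_laplacian` and `nnz_per_row_of_powers` are hypothetical helpers.

```python
import numpy as np
import scipy.sparse as sp


def grid_laplacian(n):
    """Combinatorial Laplacian of an n x n periodic grid (4 neighbours per node).

    A toy stand-in for a spherical pixel graph, not the DeepSphere graph.
    """
    ring = sp.diags([1, 1], [-1, 1], shape=(n, n), format="lil")
    ring[0, n - 1] = ring[n - 1, 0] = 1          # close the ring (periodic boundary)
    ring = ring.tocsr()
    eye = sp.identity(n, format="csr")
    A = sp.kron(ring, eye) + sp.kron(eye, ring)  # adjacency of the 2D torus
    degree = np.asarray(A.sum(axis=1)).ravel()
    return (sp.diags(degree) - A).tocsr()


def nnz_per_row_of_powers(L, max_power):
    """Average number of stored entries per row of L^k for k = 1..max_power."""
    M = L.shape[0]
    P = sp.identity(M, format="csr")
    counts = []
    for _ in range(max_power):
        P = (P @ L).tocsr()
        counts.append(P.nnz / M)
    return counts


print(nnz_per_row_of_powers(grid_laplacian(20), 5))  # [5.0, 13.0, 25.0, 41.0, 61.0]
```

On this toy grid the counts equal the size of the k-hop neighbourhood, 2k^2 + 2k + 1, so every extra degree adds a ring of new nonzeros per row. The exact numbers will differ for the actual graph used in the network.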

Arne-Thomsen commented 8 months ago

I think the even bigger problem is that this conflicts directly with the recent pull request https://github.com/deepsphere/deepsphere-cosmo-tf2/pull/8: for the ~5000 square-degree footprint and multiple input channels of DES, I ran into the int32 addressing limitation of tf.sparse.sparse_dense_matmul even with the current loop implementation, so I had to introduce another loop to be able to process these inputs.

It's a pity that tf.sparse.sparse_dense_matmul doesn't allow larger inputs.
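One generic way to stay below a per-call element limit is to split the dense operand along its column (flattened batch × channel) axis and concatenate the results. The sketch below shows that idea with SciPy in place of `tf.sparse.sparse_dense_matmul`; the 2^31 - 1 bound, the `max_elements` parameter and the function name are assumptions, and this is not necessarily what pull request #8 does.

```python
import numpy as np
import scipy.sparse as sp

INT32_MAX = 2**31 - 1  # assumed per-call limit on the number of dense elements


def chunked_sparse_dense_matmul(L, x, max_elements=INT32_MAX):
    """Compute L @ x in column chunks so each dense chunk holds at most max_elements entries.

    L: (M, M) sparse matrix; x: (M, C) dense, where C is a flattened batch*channel axis.
    If a single column already exceeds max_elements, chunks fall back to one column each.
    """
    M, C = x.shape
    cols_per_chunk = max(1, max_elements // M)
    pieces = [L @ x[:, start:start + cols_per_chunk]
              for start in range(0, C, cols_per_chunk)]
    return np.concatenate(pieces, axis=1)


# Usage with a tiny artificial limit so the chunking is actually exercised.
rng = np.random.default_rng(0)
dense = rng.standard_normal((50, 50))
L = sp.csr_matrix(np.where(rng.random((50, 50)) < 0.1, dense, 0.0))
x = rng.standard_normal((50, 37))
out = chunked_sparse_dense_matmul(L, x, max_elements=50 * 8)  # 8 columns per chunk
print(np.allclose(out, L @ x))  # True
```

Chunking over columns keeps each matmul independent, so the chunks could also be dispatched in parallel; the cost is one extra concatenation.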

tomaszkacprzak commented 8 months ago

Hey Janis!

What I have in mind is that the block-sparse Laplacian would be a matrix containing N blocks on the diagonal, each corresponding to the Chebyshev polynomial of the Laplacian of order 1..N. We would build it with the explicit (not recursive) definition of Chebyshev polynomials (https://en.wikipedia.org/wiki/Chebyshev_polynomials#First_kind). So the total number of elements for this matrix would be just N * count_nonzero(L). The sparse-dense matmul would be launched just once. Where do you expect the multiplication of the Laplacian with itself? Is it in subsequent conv layers? Or am I missing something?

Arne: the 2^31 limit is indeed the problem, and in that regime the block-sparse L won't help. But I think the feature maps further downstream in the network are much smaller, which is where the speed-up may come from.
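A back-of-the-envelope check of this argument: tiling multiplies the dense operand by N, while HEALPix pooling shrinks the map by a factor of 4 per level. The sketch below uses hypothetical numbers (a full-sky nside=1024 map, 64 features, degree 5, batch ignored), not the DES setup, and a hypothetical helper `levels_that_fit`.

```python
INT32_MAX = 2**31 - 1  # assumed element limit of a single sparse-dense matmul


def levels_that_fit(n_pixels, n_features, degree, n_levels, limit=INT32_MAX):
    """For each pooling level, check whether the N-times tiled dense operand fits the limit.

    Assumes HEALPix-style pooling, where every level has 4x fewer pixels; batch is ignored.
    """
    fits = []
    for level in range(n_levels):
        m = n_pixels // 4**level
        fits.append(degree * m * n_features <= limit)
    return fits


# Hypothetical example: full-sky nside=1024 map, 64 features, degree 5.
print(levels_that_fit(12 * 1024**2, 64, 5, 4))  # [False, True, True, True]
print(12 * 1024**2 * 64 <= INT32_MAX)           # True: the untiled operand fits
```

With these numbers the tiled operand at full resolution has 4,026,531,840 entries, above 2^31 - 1, while the untiled one (805,306,368 entries) fits; one pooling level later even the tiled operand fits.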

jafluri commented 8 months ago

It's been a while, but as far as I understand it, the Chebyshev polynomials are evaluated on the graph Laplacian and then applied to the signal, so you would evaluate $T_n(L)\cdot\mathbf{x}$, where $T_n$ is the polynomial, $L$ the Laplacian and $\mathbf{x}$ the signal on the graph, potentially with many features. So you would need to calculate $L^n,\ n \in \{1,\dots,N\}$ to get your block matrix.
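Writing out the first few polynomials of the Laplacian with the standard recurrence $T_{n+1}(L) = 2L\,T_n(L) - T_{n-1}(L)$ makes this concrete: every block $T_n(L)$ contains the power $L^n$, so it generally inherits the fill-in of $L^n$. The trigonometric form $T_n(\cos\theta) = \cos(n\theta)$ acts on the eigenvalues of $L$, so evaluating it directly would require an eigendecomposition.

$$
\begin{aligned}
T_0(L) &= I, & T_1(L) &= L, \\
T_2(L) &= 2L^2 - I, & T_3(L) &= 4L^3 - 3L, \\
T_4(L) &= 8L^4 - 8L^2 + I, & T_5(L) &= 16L^5 - 20L^3 + 5L.
\end{aligned}
$$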

If the graph does not have too many nodes, it would probably be feasible even if the blocks are dense, so one could implement it and only enable it for certain resolutions. I also believe that sparse matrices with dense blocks now have optimized kernels, which should be even more efficient than a plain sparse-dense multiplication.
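The resolution-dependent switch suggested here could look roughly like the following NumPy/SciPy sketch. The `dense_max_pixels` threshold, its default of 1024, and the function names are hypothetical; a real layer would precompute the dense stack once per resolution rather than on every call.

```python
import numpy as np
import scipy.sparse as sp


def chebyshev_stack_dense(L, degree):
    """Dense (N, M, M) array holding T_1(L), ..., T_N(L); N * M**2 entries, so small M only."""
    L = L.toarray() if sp.issparse(L) else np.asarray(L, dtype=float)
    T_prev, T_curr = np.eye(L.shape[0]), L
    stack = [T_curr]
    for _ in range(2, degree + 1):
        T_prev, T_curr = T_curr, 2 * L @ T_curr - T_prev
        stack.append(T_curr)
    return np.stack(stack)


def chebyshev_maps(L, x, degree, dense_max_pixels=1024):
    """Return [T_1(L) x, ..., T_N(L) x] as (N, M, F), picking a path by graph size.

    dense_max_pixels is a hypothetical switch, not a tuned value. The dense stack is
    rebuilt on every call here for brevity; a real layer would precompute it once.
    """
    if L.shape[0] <= dense_max_pixels:
        return np.matmul(chebyshev_stack_dense(L, degree), x)  # one batched dense matmul
    t_prev, t_curr = x, L @ x              # T_0(L) x and T_1(L) x
    outputs = [t_curr]
    for _ in range(2, degree + 1):         # sparse recurrence: one matmul per order
        t_prev, t_curr = t_curr, 2 * (L @ t_curr) - t_prev
        outputs.append(t_curr)
    return np.stack(outputs)


# Usage: both paths agree on a toy diagonal "Laplacian".
L = sp.diags([-0.9, -0.3, 0.0, 0.5, 0.8]).tocsr()
x = np.arange(10, dtype=float).reshape(5, 2)
dense_path = chebyshev_maps(L, x, 5)                        # M = 5 <= 1024: dense stack
sparse_path = chebyshev_maps(L, x, 5, dense_max_pixels=1)   # forces the recurrence
print(np.allclose(dense_path, sparse_path))  # True
```

The dense stack holds N·M² entries, e.g. 5 × 1024² ≈ 5.2 million values (about 42 MB in float64) at the default threshold, which is why this path only makes sense for coarse resolutions.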