jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Choice of ops for stencil computations run on TPUs #3341

Open hubertlu-tw opened 4 years ago

hubertlu-tw commented 4 years ago

Hi, I am investigating whether JAX is a good fit for scientific computing on TPUs. The excellent JAX tutorial on the wave equation (https://github.com/google/jax/blob/master/cloud_tpu_colabs/Wave_Equation.ipynb) helped me quickly understand how to use JAX and what its advantages are. One question I still have: why do convolution-based ops perform so much worse than element-wise ops for stencil computations on Cloud TPU?

To take advantage of the compute power of the MXU (the TPU's matrix multiply unit), I implemented the stencil as two 1D convolutions, mirroring the element-wise version. The following snippets compute a 5-point stencil for 2D problems.

Element-wise ops:

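```python
# `shift(array, offset, axis)` is a helper (not shown) that shifts `array`
# by `offset` positions along `axis`.
left = shift(array, +1, axis=0)
right = shift(array, -1, axis=0)
up = shift(array, +1, axis=1)
down = shift(array, -1, axis=1)
convolved = left + right + up + down - 4 * array  # 5-point Laplacian stencil
```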

Convolution-based ops:

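```python
from jax import lax

# Here `array` is 4D in NHWC layout, i.e. shape (1, H, W, 1).
# `make_kernel` (not shown) is assumed to reshape a 2D stencil into an
# HWIO kernel of shape (kh, kw, 1, 1).
col_F = make_kernel([[1., -2., 1.]])  # 1x3: second difference along W
row_F = make_kernel([[1.],
                     [-2.],
                     [1.]])           # 3x1: second difference along H
dn_col = lax.conv_dimension_numbers(array.shape, col_F.shape, ('NHWC', 'HWIO', 'NHWC'))
dn_row = lax.conv_dimension_numbers(array.shape, row_F.shape, ('NHWC', 'HWIO', 'NHWC'))
col_ops = lax.conv_general_dilated(array, col_F, (1, 1), 'SAME', (1, 1), (1, 1), dn_col)
row_ops = lax.conv_general_dilated(array, row_F, (1, 1), 'SAME', (1, 1), (1, 1), dn_row)
convolved = (col_ops + row_ops)[0, :, :, 0]  # drop the N and C axes -> (H, W)
```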
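
Both snippets are meant to compute the same 5-point Laplacian. A self-contained sketch to check that follows, assuming zero boundary values (which is what 'SAME' padding implies for the convolutions). Neither `shift` nor `make_kernel` appears in the post, so the versions below are hypothetical reconstructions, not the tutorial's code; `laplacian_elementwise` and `laplacian_conv` are likewise illustrative names.

```python
import jax.numpy as jnp
from jax import lax

def shift(u, offset, axis):
    """Hypothetical helper: out[i] = u[i - offset] along `axis`, zero-filled at the edges."""
    n = u.shape[axis]
    pad = [(0, 0)] * u.ndim
    pad[axis] = (max(offset, 0), max(-offset, 0))
    start = max(-offset, 0)
    return lax.slice_in_dim(jnp.pad(u, pad), start, start + n, axis=axis)

def make_kernel(stencil):
    """Hypothetical helper: 2D stencil -> HWIO kernel of shape (kh, kw, 1, 1)."""
    return jnp.asarray(stencil, dtype=jnp.float32)[:, :, None, None]

def laplacian_elementwise(array):
    # Unrolled 5-point stencil on a 2D (H, W) array.
    left = shift(array, +1, axis=0)
    right = shift(array, -1, axis=0)
    up = shift(array, +1, axis=1)
    down = shift(array, -1, axis=1)
    return left + right + up + down - 4 * array

def laplacian_conv(array):
    # Same stencil as two 1D convolutions over an NHWC view of the array.
    x = array[None, :, :, None]  # (H, W) -> (1, H, W, 1)
    col_F = make_kernel([[1., -2., 1.]])
    row_F = make_kernel([[1.], [-2.], [1.]])
    dn_col = lax.conv_dimension_numbers(x.shape, col_F.shape, ('NHWC', 'HWIO', 'NHWC'))
    dn_row = lax.conv_dimension_numbers(x.shape, row_F.shape, ('NHWC', 'HWIO', 'NHWC'))
    col_ops = lax.conv_general_dilated(x, col_F, (1, 1), 'SAME', (1, 1), (1, 1), dn_col)
    row_ops = lax.conv_general_dilated(x, row_F, (1, 1), 'SAME', (1, 1), (1, 1), dn_row)
    return (col_ops + row_ops)[0, :, :, 0]

if __name__ == "__main__":
    u = jnp.sin(jnp.arange(48, dtype=jnp.float32)).reshape(8, 6)
    diff = jnp.max(jnp.abs(laplacian_elementwise(u) - laplacian_conv(u)))
    print("max abs difference:", float(diff))
```

On a TPU the two results may differ slightly, because f32 convolutions run at reduced (bfloat16-based) precision by default; passing `precision=lax.Precision.HIGHEST` to `conv_general_dilated` narrows that gap.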
jekbradbury commented 4 years ago

The MXU needs dot products of length at least 128 for full throughput; for a convolution, the dot-product length is the product of all kernel spatial dimensions and the input feature count. In your case the dot-product length is only 3, so you can use at best 3/128 (about 2%) of the MXU FLOPs, and likely even less since the computation has low arithmetic intensity. Unrolled element-wise computations, as in your first snippet, are typically a better way to implement finite-difference/stencil convolutions.
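
The arithmetic above can be written out as a small helper. This is a sketch under a simplifying assumption: it treats the 128-length threshold from the comment as the only limit and ignores arithmetic intensity and memory bandwidth, so it gives a ceiling, not a performance prediction. The names `conv_contraction_length` and `mxu_fraction_bound` are hypothetical.

```python
import math

# Minimum dot-product (contraction) length for full MXU throughput, per the comment above.
MXU_FULL_THROUGHPUT_LENGTH = 128

def conv_contraction_length(kernel_spatial_shape, in_features):
    """Length of each dot product in a convolution: kernel spatial size times input features."""
    return math.prod(kernel_spatial_shape) * in_features

def mxu_fraction_bound(kernel_spatial_shape, in_features):
    """Crude upper bound on the usable fraction of MXU FLOPs, from contraction length alone."""
    length = conv_contraction_length(kernel_spatial_shape, in_features)
    return min(1.0, length / MXU_FULL_THROUGHPUT_LENGTH)

if __name__ == "__main__":
    # The question's kernels: 1x3 and 3x1, one input feature each.
    for shape in [(1, 3), (3, 1)]:
        print(shape, mxu_fraction_bound(shape, 1))  # 3/128 = 0.0234375
    # A 3x3 kernel over 16 input features: 3 * 3 * 16 = 144 >= 128.
    print((3, 3), mxu_fraction_bound((3, 3), 16))   # 1.0
```

For a single-field stencil there is only one input feature, so the contraction length stays tiny no matter how the two 1D kernels are arranged, which is why the unrolled element-wise form tends to win.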

I'm not sure how well it works with JAX on Cloud TPU right now, but the Cloud TPU Profiler can help you see how XLA compiles operations for the hardware and how well they use the MXU's FLOPs and memory bandwidth.
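
Short of a full profile, one way to check what each formulation hands to XLA is to print the lowered program. The sketch below assumes a JAX version with the `jax.jit(f).lower(...)` ahead-of-time API; `laplacian_shifted`, `laplacian_conv`, and `lowered_text` are hypothetical names, and the padded-slice stencil assumes zero boundaries to match 'SAME' padding.

```python
import jax
import jax.numpy as jnp
from jax import lax

def laplacian_shifted(u):
    # Element-wise 5-point stencil on a zero-padded copy of u, shape (H, W).
    p = jnp.pad(u, 1)
    return p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:] - 4 * u

def laplacian_conv(u):
    # The question's two 1D convolutions, with u reshaped to NHWC (1, H, W, 1).
    x = u[None, :, :, None]
    col_F = jnp.array([1., -2., 1.], dtype=u.dtype).reshape(1, 3, 1, 1)  # HWIO
    row_F = jnp.array([1., -2., 1.], dtype=u.dtype).reshape(3, 1, 1, 1)  # HWIO
    dims = ('NHWC', 'HWIO', 'NHWC')
    col_ops = lax.conv_general_dilated(x, col_F, (1, 1), 'SAME', dimension_numbers=dims)
    row_ops = lax.conv_general_dilated(x, row_F, (1, 1), 'SAME', dimension_numbers=dims)
    return (col_ops + row_ops)[0, :, :, 0]

def lowered_text(f, u):
    # StableHLO that jax.jit hands to XLA, before backend-specific optimization.
    return jax.jit(f).lower(u).as_text()

if __name__ == "__main__":
    u = jnp.ones((8, 8), jnp.float32)
    for f in (laplacian_shifted, laplacian_conv):
        print(f.__name__, "contains convolution:", "convolution" in lowered_text(f, u))
```

On a TPU, `jax.jit(f).lower(u).compile().as_text()` prints the HLO after XLA's backend optimizations, which can show whether the element-wise version was fused into a single kernel, and wrapping a run in `with jax.profiler.trace(log_dir):` records a trace that can be viewed in TensorBoard or Perfetto.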