jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Sharding is much slower than pmap for while loops of varying length #20968

Open lockwo opened 6 months ago

lockwo commented 6 months ago

Description

As the title indicates, with a double while loop in which the length of the inner while loop may change between outer-loop steps, pmap is substantially faster than sharding. This may sound contrived, but it is exactly what happens in other packages, such as diffrax, where I first identified this issue: https://github.com/patrick-kidger/diffrax/issues/407. I believe there are two possibilities: 1) I am using sharding incorrectly and that is why it is slow (very possible, as I am new to sharding), or 2) something else is going on in sharding.
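For readers unfamiliar with the pattern, here is a minimal sketch (not part of the original report) of a while loop whose trip count depends on its input. The function name `count_steps`, the threshold 2.0, and the step size 0.5 are illustrative assumptions. Under `jax.vmap`, a batched while loop keeps iterating until the condition is false for every element of the batch, so the batch as a whole takes as many steps as its slowest element.

```python
import jax
import jax.numpy as jnp


def count_steps(y0):
    # Hypothetical toy loop: add 0.5 to y until it reaches 2.0 and
    # count how many iterations that took.
    def cond(state):
        y, _ = state
        return y < 2.0

    def body(state):
        y, n = state
        return y + 0.5, n + 1

    _, n = jax.lax.while_loop(cond, body, (y0, jnp.int32(0)))
    return n


y0s = jnp.array([0.0, 1.0, 1.75, 3.0])
steps = jax.vmap(count_steps)(y0s)
print(steps)  # per-element step counts: [4 2 1 0]
```

The per-element counts differ, which is the "varying length" situation described above: elements that finish early still have to wait for the longest-running element in the same batch.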

I have included an MVC below. I ran it on both CPU and GPU, and the slowdown is even more noticeable on GPU.

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils

def f(t, y, theta):
    return jnp.abs(jnp.sin(t)) + theta * y

def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]

batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

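num_devices = len(jax.devices())
# Shard the batch axis across all devices; the trailing axis is not split
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()  # not used below

# pmap needs an explicit leading device axis: (num_devices, batch_per_device, ...)
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

# Place the inputs on the devices with the batch axis sharded
x, y = jax.device_put((inits, keys), sharding)

fn = jax.jit(jax.vmap(solve))  # sharded version: jit follows the input sharding
pmap_fn = jax.pmap(fn)  # pmap version: each device runs fn on its own slice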
# Warm up both functions so that compilation time is excluded from the timings
_ = fn(x, y).block_until_ready()
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

%%timeit
_ = fn(x, y).block_until_ready()

CPU: 5.11 ms ± 53.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
GPU: 1.18 s ± 48.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

CPU: 251 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
GPU: 3.93 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
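The timings above rely on the Jupyter `%%timeit` magic. As a rough sketch for reproducing the comparison outside a notebook, a small helper along these lines could be used. The name `bench` and the `repeats` parameter are hypothetical and not part of the original report; the helper only mirrors the same idea of warming up first and blocking on the result.

```python
import time

import jax
import jax.numpy as jnp


def bench(fn, *args, repeats=10):
    """Return the mean wall-clock time per call of fn(*args), in seconds."""
    # Warm-up call: triggers compilation so it is not included in the timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so block until the result is ready.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


# Usage with a stand-in function; in the report this would be
# bench(fn, x, y) and bench(pmap_fn, inits_pmap, keys_pmap).
mean_seconds = bench(jax.jit(jnp.sin), jnp.ones(1000))
print(f"{mean_seconds * 1e6:.1f} µs per call")
```

Unlike `%%timeit`, this reports only a mean over a fixed number of calls, so the numbers are not directly comparable with the figures quoted above.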

System info (python version, jaxlib version, accelerator, etc.)

CPU:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 | packaged by conda-forge | [Clang 16.0.6 ]
jax.devices (10 total, 10 local): [CpuDevice(id=0) CpuDevice(id=1) ... CpuDevice(id=8) CpuDevice(id=9)]
process_count: 1
platform: uname_result(system='Darwin', node=, release='23.3.0', version='Darwin Kernel Version 23.3.0; root:xnu-10002.81.5~7/RELEASE_ARM64_T6000', machine='arm64')

GPU:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 [GCC 11.4.0]
jax.devices (3 total, 3 local): [cuda(id=0) cuda(id=1) cuda(id=2)]
process_count: 1
platform: uname_result(system='Linux', node=, release='5.15.0-89-generic', version='#99-Ubuntu SMP, machine='x86_64')

$ nvidia-smi
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:0A:00.0 Off |                    0 |
| N/A   43C    P0              76W / 400W |  62219MiB / 81920MiB |     20%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:4B:00.0 Off |                    0 |
| N/A   50C    P0              82W / 400W |  62349MiB / 81920MiB |     13%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          On  | 00000000:C3:00.0 Off |                    0 |
| N/A   42C    P0              72W / 400W |  62205MiB / 81920MiB |     14%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1033933      C  python        62208MiB |
|    1   N/A  N/A   1033933      C  python        62338MiB |
|    2   N/A  N/A   1033933      C  python        62194MiB |
+---------------------------------------------------------------------------------------+

lockwo commented 6 months ago

To be extra clear, if you set count_initial = 0, you see the same speed for sharding and for pmap. The slowdown appears only when count_initial varies per shard.

lockwo commented 6 months ago

Also, if you replace the body and do

new_t, new_y, theta = inner_loop_body((t, y, theta))
#new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
return (new_t, new_y, theta, count + 1)

you see the same speed, which indicates that the inner (second) while loop is necessary for the slowdown.

lockwo commented 5 months ago

@patrick-kidger you mentioned in https://github.com/patrick-kidger/diffrax/issues/407 that you suspect this is within XLA. Do you have any advice on how to approach that? I haven't investigated an XLA system this complex before. Even my reduced-complexity example (shown below) yields HLO that is not especially readable (shown further below). Is there a go-to issue or piece of XLA/JAX documentation on identifying whether a bug is in JAX or in XLA, and on how to spot it?

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils

def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 2

    def inner_loop_body(state):
        t, y, theta = state
        return (t + 0.1, y + 0.1, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]
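The HLO text that follows is the lowering of the function above for a batch of 10. As a hedged sketch of one way to obtain both the program JAX emits and the program XLA actually runs, the ahead-of-time `lower`/`compile` API can be used as shown here. This is not necessarily how the dump in this comment was produced, and `toy_solve` is a hypothetical stand-in for `solve`.

```python
import jax
import jax.numpy as jnp


def toy_solve(y0):
    # Hypothetical stand-in for `solve`: a loop whose trip count depends on y0.
    return jax.lax.while_loop(lambda y: y < 2.0, lambda y: y + 0.1, y0)


fn = jax.jit(jax.vmap(toy_solve))
y0s = jnp.linspace(0.0, 1.0, 10)

lowered = fn.lower(y0s)
# StableHLO as emitted by JAX, before any XLA optimization passes have run.
stablehlo_text = lowered.as_text()
# HLO after XLA has optimized (and, for sharded inputs, partitioned) the program.
optimized_hlo_text = lowered.compile().as_text()

print(stablehlo_text[:500])
print(optimized_hlo_text[:500])
```

Comparing the two texts is a rough heuristic rather than an official procedure: an operation that appears only in the optimized text was introduced by XLA's passes, not by JAX's lowering.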
HloModule xla_computation_solve, entry_computation_layout={(f32[10,1]{1,0}, u32[10,2]{1,0})->(f32[10,1]{1,0})}

clip.3 {
  Arg_2.6 = s32[] parameter(2)
  Arg_1.5 = s32[] parameter(1)
  Arg_0.4 = s32[] parameter(0)
  maximum.7 = s32[] maximum(Arg_1.5, Arg_0.4)
  ROOT minimum.8 = s32[] minimum(Arg_2.6, maximum.7)
}

clip_0.9 {
  Arg_2.12 = s32[] parameter(2)
  Arg_1.11 = s32[] parameter(1)
  Arg_0.10 = s32[] parameter(0)
  maximum.13 = s32[] maximum(Arg_1.11, Arg_0.10)
  ROOT minimum.14 = s32[] minimum(Arg_2.12, maximum.13)
}

clip_0.15 {
  Arg_2.18 = s32[] parameter(2)
  Arg_1.17 = s32[] parameter(1)
  Arg_0.16 = s32[] parameter(0)
  maximum.19 = s32[] maximum(Arg_1.17, Arg_0.16)
  ROOT minimum.20 = s32[] minimum(Arg_2.18, maximum.19)
}

region_0.21 {
  arg_tuple.22 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.23 = s32[] get-tuple-element(arg_tuple.22), index=0
  constant.33 = s32[] constant(1)
  add.87 = s32[] add(get-tuple-element.23, constant.33)
  get-tuple-element.24 = s32[] get-tuple-element(arg_tuple.22), index=1
  add.34 = s32[] add(get-tuple-element.24, constant.33)
  get-tuple-element.25 = u32[10,2]{1,0} get-tuple-element(arg_tuple.22), index=2
  get-tuple-element.26 = u32[10,2]{1,0} get-tuple-element(arg_tuple.22), index=3
  add.37 = u32[10,2]{1,0} add(get-tuple-element.25, get-tuple-element.26)
  get-tuple-element.30 = u32[4]{0} get-tuple-element(arg_tuple.22), index=7
  slice.35 = u32[1]{0} slice(get-tuple-element.30), slice={[0:1]}
  reshape.36 = u32[] reshape(slice.35)
  broadcast.38 = u32[10,2]{1,0} broadcast(reshape.36), dimensions={}
  shift-left.39 = u32[10,2]{1,0} shift-left(get-tuple-element.26, broadcast.38)
  constant.32 = u32[] constant(32)
  subtract.40 = u32[] subtract(constant.32, reshape.36)
  broadcast.41 = u32[10,2]{1,0} broadcast(subtract.40), dimensions={}
  shift-right-logical.42 = u32[10,2]{1,0} shift-right-logical(get-tuple-element.26, broadcast.41)
  or.43 = u32[10,2]{1,0} or(shift-left.39, shift-right-logical.42)
  xor.44 = u32[10,2]{1,0} xor(add.37, or.43)
  add.47 = u32[10,2]{1,0} add(add.37, xor.44)
  slice.45 = u32[1]{0} slice(get-tuple-element.30), slice={[1:2]}
  reshape.46 = u32[] reshape(slice.45)
  broadcast.48 = u32[10,2]{1,0} broadcast(reshape.46), dimensions={}
  shift-left.49 = u32[10,2]{1,0} shift-left(xor.44, broadcast.48)
  subtract.50 = u32[] subtract(constant.32, reshape.46)
  broadcast.51 = u32[10,2]{1,0} broadcast(subtract.50), dimensions={}
  shift-right-logical.52 = u32[10,2]{1,0} shift-right-logical(xor.44, broadcast.51)
  or.53 = u32[10,2]{1,0} or(shift-left.49, shift-right-logical.52)
  xor.54 = u32[10,2]{1,0} xor(add.47, or.53)
  add.57 = u32[10,2]{1,0} add(add.47, xor.54)
  slice.55 = u32[1]{0} slice(get-tuple-element.30), slice={[2:3]}
  reshape.56 = u32[] reshape(slice.55)
  broadcast.58 = u32[10,2]{1,0} broadcast(reshape.56), dimensions={}
  shift-left.59 = u32[10,2]{1,0} shift-left(xor.54, broadcast.58)
  subtract.60 = u32[] subtract(constant.32, reshape.56)
  broadcast.61 = u32[10,2]{1,0} broadcast(subtract.60), dimensions={}
  shift-right-logical.62 = u32[10,2]{1,0} shift-right-logical(xor.54, broadcast.61)
  or.63 = u32[10,2]{1,0} or(shift-left.59, shift-right-logical.62)
  xor.64 = u32[10,2]{1,0} xor(add.57, or.63)
  add.67 = u32[10,2]{1,0} add(add.57, xor.64)
  get-tuple-element.27 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=4
  broadcast.75 = u32[10,1]{1,0} broadcast(get-tuple-element.27), dimensions={0,1}
  reshape.76 = u32[10]{0} reshape(broadcast.75)
  broadcast.77 = u32[10,2]{1,0} broadcast(reshape.76), dimensions={0}
  add.78 = u32[10,2]{1,0} add(add.67, broadcast.77)
  slice.65 = u32[1]{0} slice(get-tuple-element.30), slice={[3:4]}
  reshape.66 = u32[] reshape(slice.65)
  broadcast.68 = u32[10,2]{1,0} broadcast(reshape.66), dimensions={}
  shift-left.69 = u32[10,2]{1,0} shift-left(xor.64, broadcast.68)
  subtract.70 = u32[] subtract(constant.32, reshape.66)
  broadcast.71 = u32[10,2]{1,0} broadcast(subtract.70), dimensions={}
  shift-right-logical.72 = u32[10,2]{1,0} shift-right-logical(xor.64, broadcast.71)
  or.73 = u32[10,2]{1,0} or(shift-left.69, shift-right-logical.72)
  xor.74 = u32[10,2]{1,0} xor(add.67, or.73)
  get-tuple-element.28 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=5
  broadcast.79 = u32[10,1]{1,0} broadcast(get-tuple-element.28), dimensions={0,1}
  reshape.80 = u32[10]{0} reshape(broadcast.79)
  broadcast.81 = u32[10,2]{1,0} broadcast(reshape.80), dimensions={0}
  add.82 = u32[10,2]{1,0} add(xor.74, broadcast.81)
  add.83 = s32[] add(get-tuple-element.24, constant.33)
  convert.84 = u32[] convert(add.83)
  broadcast.85 = u32[10,2]{1,0} broadcast(convert.84), dimensions={}
  add.86 = u32[10,2]{1,0} add(add.82, broadcast.85)
  get-tuple-element.29 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=6
  get-tuple-element.31 = u32[4]{0} get-tuple-element(arg_tuple.22), index=8
  ROOT tuple.88 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.87, add.34, add.78, add.86, get-tuple-element.28, get-tuple-element.29, get-tuple-element.27, get-tuple-element.31, get-tuple-element.30)
}

region_1.89 {
  arg_tuple.90 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.92 = s32[] get-tuple-element(arg_tuple.90), index=1
  get-tuple-element.93 = u32[10,2]{1,0} get-tuple-element(arg_tuple.90), index=2
  get-tuple-element.94 = u32[10,2]{1,0} get-tuple-element(arg_tuple.90), index=3
  get-tuple-element.95 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=4
  get-tuple-element.96 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=5
  get-tuple-element.97 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=6
  get-tuple-element.98 = u32[4]{0} get-tuple-element(arg_tuple.90), index=7
  get-tuple-element.99 = u32[4]{0} get-tuple-element(arg_tuple.90), index=8
  get-tuple-element.91 = s32[] get-tuple-element(arg_tuple.90), index=0
  constant.100 = s32[] constant(5)
  ROOT compare.101 = pred[] compare(get-tuple-element.91, constant.100), direction=LT
}

_threefry_split.102 {
  constant.106 = s32[] constant(0)
  iota.109 = u32[4]{0} iota(), iota_dimension=0
  slice.112 = u32[2]{0} slice(iota.109), slice={[0:2]}
  reshape.114 = u32[1,2]{1,0} reshape(slice.112)
  broadcast.118 = u32[1,2]{1,0} broadcast(reshape.114), dimensions={0,1}
  reshape.119 = u32[2]{0} reshape(broadcast.118)
  broadcast.120 = u32[10,2]{1,0} broadcast(reshape.119), dimensions={1}
  Arg_0.103 = u32[10,2]{1,0} parameter(0)
  slice.110 = u32[10,1]{1,0} slice(Arg_0.103), slice={[0:10], [0:1]}
  broadcast.121 = u32[10,1]{1,0} broadcast(slice.110), dimensions={0,1}
  reshape.122 = u32[10]{0} reshape(broadcast.121)
  broadcast.123 = u32[10,2]{1,0} broadcast(reshape.122), dimensions={0}
  add.124 = u32[10,2]{1,0} add(broadcast.120, broadcast.123)
  slice.113 = u32[2]{0} slice(iota.109), slice={[2:4]}
  reshape.115 = u32[1,2]{1,0} reshape(slice.113)
  broadcast.125 = u32[1,2]{1,0} broadcast(reshape.115), dimensions={0,1}
  reshape.126 = u32[2]{0} reshape(broadcast.125)
  broadcast.127 = u32[10,2]{1,0} broadcast(reshape.126), dimensions={1}
  slice.111 = u32[10,1]{1,0} slice(Arg_0.103), slice={[0:10], [1:2]}
  broadcast.128 = u32[10,1]{1,0} broadcast(slice.111), dimensions={0,1}
  reshape.129 = u32[10]{0} reshape(broadcast.128)
  broadcast.130 = u32[10,2]{1,0} broadcast(reshape.129), dimensions={0}
  add.131 = u32[10,2]{1,0} add(broadcast.127, broadcast.130)
  xor.116 = u32[10,1]{1,0} xor(slice.110, slice.111)
  constant.104 = u32[] constant(466688986)
  broadcast.105 = u32[10,1]{1,0} broadcast(constant.104), dimensions={}
  xor.117 = u32[10,1]{1,0} xor(xor.116, broadcast.105)
  constant.108 = u32[4]{0} constant({13, 15, 26, 6})
  constant.107 = u32[4]{0} constant({17, 29, 16, 24})
  tuple.132 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.106, constant.106, add.124, add.131, slice.111, xor.117, slice.110, constant.108, constant.107)
  while.133 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.132), condition=region_1.89, body=region_0.21
  get-tuple-element.134 = s32[] get-tuple-element(while.133), index=0
  get-tuple-element.135 = s32[] get-tuple-element(while.133), index=1
  get-tuple-element.138 = u32[10,1]{1,0} get-tuple-element(while.133), index=4
  get-tuple-element.139 = u32[10,1]{1,0} get-tuple-element(while.133), index=5
  get-tuple-element.140 = u32[10,1]{1,0} get-tuple-element(while.133), index=6
  get-tuple-element.141 = u32[4]{0} get-tuple-element(while.133), index=7
  get-tuple-element.142 = u32[4]{0} get-tuple-element(while.133), index=8
  get-tuple-element.136 = u32[10,2]{1,0} get-tuple-element(while.133), index=2
  get-tuple-element.137 = u32[10,2]{1,0} get-tuple-element(while.133), index=3
  concatenate.143 = u32[10,4]{1,0} concatenate(get-tuple-element.136, get-tuple-element.137), dimensions={1}
  ROOT reshape.144 = u32[10,2,2]{2,1,0} reshape(concatenate.143)
}

region_2.145 {
  arg_tuple.146 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.147 = s32[] get-tuple-element(arg_tuple.146), index=0
  constant.157 = s32[] constant(1)
  add.205 = s32[] add(get-tuple-element.147, constant.157)
  get-tuple-element.148 = s32[] get-tuple-element(arg_tuple.146), index=1
  add.158 = s32[] add(get-tuple-element.148, constant.157)
  get-tuple-element.149 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=2
  get-tuple-element.150 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=3
  add.161 = u32[10,1]{1,0} add(get-tuple-element.149, get-tuple-element.150)
  get-tuple-element.154 = u32[4]{0} get-tuple-element(arg_tuple.146), index=7
  slice.159 = u32[1]{0} slice(get-tuple-element.154), slice={[0:1]}
  reshape.160 = u32[] reshape(slice.159)
  broadcast.162 = u32[10,1]{1,0} broadcast(reshape.160), dimensions={}
  shift-left.163 = u32[10,1]{1,0} shift-left(get-tuple-element.150, broadcast.162)
  constant.156 = u32[] constant(32)
  subtract.164 = u32[] subtract(constant.156, reshape.160)
  broadcast.165 = u32[10,1]{1,0} broadcast(subtract.164), dimensions={}
  shift-right-logical.166 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.150, broadcast.165)
  or.167 = u32[10,1]{1,0} or(shift-left.163, shift-right-logical.166)
  xor.168 = u32[10,1]{1,0} xor(add.161, or.167)
  add.171 = u32[10,1]{1,0} add(add.161, xor.168)
  slice.169 = u32[1]{0} slice(get-tuple-element.154), slice={[1:2]}
  reshape.170 = u32[] reshape(slice.169)
  broadcast.172 = u32[10,1]{1,0} broadcast(reshape.170), dimensions={}
  shift-left.173 = u32[10,1]{1,0} shift-left(xor.168, broadcast.172)
  subtract.174 = u32[] subtract(constant.156, reshape.170)
  broadcast.175 = u32[10,1]{1,0} broadcast(subtract.174), dimensions={}
  shift-right-logical.176 = u32[10,1]{1,0} shift-right-logical(xor.168, broadcast.175)
  or.177 = u32[10,1]{1,0} or(shift-left.173, shift-right-logical.176)
  xor.178 = u32[10,1]{1,0} xor(add.171, or.177)
  add.181 = u32[10,1]{1,0} add(add.171, xor.178)
  slice.179 = u32[1]{0} slice(get-tuple-element.154), slice={[2:3]}
  reshape.180 = u32[] reshape(slice.179)
  broadcast.182 = u32[10,1]{1,0} broadcast(reshape.180), dimensions={}
  shift-left.183 = u32[10,1]{1,0} shift-left(xor.178, broadcast.182)
  subtract.184 = u32[] subtract(constant.156, reshape.180)
  broadcast.185 = u32[10,1]{1,0} broadcast(subtract.184), dimensions={}
  shift-right-logical.186 = u32[10,1]{1,0} shift-right-logical(xor.178, broadcast.185)
  or.187 = u32[10,1]{1,0} or(shift-left.183, shift-right-logical.186)
  xor.188 = u32[10,1]{1,0} xor(add.181, or.187)
  add.191 = u32[10,1]{1,0} add(add.181, xor.188)
  get-tuple-element.151 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=4
  add.199 = u32[10,1]{1,0} add(add.191, get-tuple-element.151)
  slice.189 = u32[1]{0} slice(get-tuple-element.154), slice={[3:4]}
  reshape.190 = u32[] reshape(slice.189)
  broadcast.192 = u32[10,1]{1,0} broadcast(reshape.190), dimensions={}
  shift-left.193 = u32[10,1]{1,0} shift-left(xor.188, broadcast.192)
  subtract.194 = u32[] subtract(constant.156, reshape.190)
  broadcast.195 = u32[10,1]{1,0} broadcast(subtract.194), dimensions={}
  shift-right-logical.196 = u32[10,1]{1,0} shift-right-logical(xor.188, broadcast.195)
  or.197 = u32[10,1]{1,0} or(shift-left.193, shift-right-logical.196)
  xor.198 = u32[10,1]{1,0} xor(add.191, or.197)
  get-tuple-element.152 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=5
  add.200 = u32[10,1]{1,0} add(xor.198, get-tuple-element.152)
  add.201 = s32[] add(get-tuple-element.148, constant.157)
  convert.202 = u32[] convert(add.201)
  broadcast.203 = u32[10,1]{1,0} broadcast(convert.202), dimensions={}
  add.204 = u32[10,1]{1,0} add(add.200, broadcast.203)
  get-tuple-element.153 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=6
  get-tuple-element.155 = u32[4]{0} get-tuple-element(arg_tuple.146), index=8
  ROOT tuple.206 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.205, add.158, add.199, add.204, get-tuple-element.152, get-tuple-element.153, get-tuple-element.151, get-tuple-element.155, get-tuple-element.154)
}

region_3.207 {
  arg_tuple.208 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.210 = s32[] get-tuple-element(arg_tuple.208), index=1
  get-tuple-element.211 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=2
  get-tuple-element.212 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=3
  get-tuple-element.213 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=4
  get-tuple-element.214 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=5
  get-tuple-element.215 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=6
  get-tuple-element.216 = u32[4]{0} get-tuple-element(arg_tuple.208), index=7
  get-tuple-element.217 = u32[4]{0} get-tuple-element(arg_tuple.208), index=8
  get-tuple-element.209 = s32[] get-tuple-element(arg_tuple.208), index=0
  constant.218 = s32[] constant(5)
  ROOT compare.219 = pred[] compare(get-tuple-element.209, constant.218), direction=LT
}

region_4.220 {
  arg_tuple.221 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.222 = s32[] get-tuple-element(arg_tuple.221), index=0
  constant.232 = s32[] constant(1)
  add.280 = s32[] add(get-tuple-element.222, constant.232)
  get-tuple-element.223 = s32[] get-tuple-element(arg_tuple.221), index=1
  add.233 = s32[] add(get-tuple-element.223, constant.232)
  get-tuple-element.224 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=2
  get-tuple-element.225 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=3
  add.236 = u32[10,1]{1,0} add(get-tuple-element.224, get-tuple-element.225)
  get-tuple-element.229 = u32[4]{0} get-tuple-element(arg_tuple.221), index=7
  slice.234 = u32[1]{0} slice(get-tuple-element.229), slice={[0:1]}
  reshape.235 = u32[] reshape(slice.234)
  broadcast.237 = u32[10,1]{1,0} broadcast(reshape.235), dimensions={}
  shift-left.238 = u32[10,1]{1,0} shift-left(get-tuple-element.225, broadcast.237)
  constant.231 = u32[] constant(32)
  subtract.239 = u32[] subtract(constant.231, reshape.235)
  broadcast.240 = u32[10,1]{1,0} broadcast(subtract.239), dimensions={}
  shift-right-logical.241 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.225, broadcast.240)
  or.242 = u32[10,1]{1,0} or(shift-left.238, shift-right-logical.241)
  xor.243 = u32[10,1]{1,0} xor(add.236, or.242)
  add.246 = u32[10,1]{1,0} add(add.236, xor.243)
  slice.244 = u32[1]{0} slice(get-tuple-element.229), slice={[1:2]}
  reshape.245 = u32[] reshape(slice.244)
  broadcast.247 = u32[10,1]{1,0} broadcast(reshape.245), dimensions={}
  shift-left.248 = u32[10,1]{1,0} shift-left(xor.243, broadcast.247)
  subtract.249 = u32[] subtract(constant.231, reshape.245)
  broadcast.250 = u32[10,1]{1,0} broadcast(subtract.249), dimensions={}
  shift-right-logical.251 = u32[10,1]{1,0} shift-right-logical(xor.243, broadcast.250)
  or.252 = u32[10,1]{1,0} or(shift-left.248, shift-right-logical.251)
  xor.253 = u32[10,1]{1,0} xor(add.246, or.252)
  add.256 = u32[10,1]{1,0} add(add.246, xor.253)
  slice.254 = u32[1]{0} slice(get-tuple-element.229), slice={[2:3]}
  reshape.255 = u32[] reshape(slice.254)
  broadcast.257 = u32[10,1]{1,0} broadcast(reshape.255), dimensions={}
  shift-left.258 = u32[10,1]{1,0} shift-left(xor.253, broadcast.257)
  subtract.259 = u32[] subtract(constant.231, reshape.255)
  broadcast.260 = u32[10,1]{1,0} broadcast(subtract.259), dimensions={}
  shift-right-logical.261 = u32[10,1]{1,0} shift-right-logical(xor.253, broadcast.260)
  or.262 = u32[10,1]{1,0} or(shift-left.258, shift-right-logical.261)
  xor.263 = u32[10,1]{1,0} xor(add.256, or.262)
  add.266 = u32[10,1]{1,0} add(add.256, xor.263)
  get-tuple-element.226 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=4
  add.274 = u32[10,1]{1,0} add(add.266, get-tuple-element.226)
  slice.264 = u32[1]{0} slice(get-tuple-element.229), slice={[3:4]}
  reshape.265 = u32[] reshape(slice.264)
  broadcast.267 = u32[10,1]{1,0} broadcast(reshape.265), dimensions={}
  shift-left.268 = u32[10,1]{1,0} shift-left(xor.263, broadcast.267)
  subtract.269 = u32[] subtract(constant.231, reshape.265)
  broadcast.270 = u32[10,1]{1,0} broadcast(subtract.269), dimensions={}
  shift-right-logical.271 = u32[10,1]{1,0} shift-right-logical(xor.263, broadcast.270)
  or.272 = u32[10,1]{1,0} or(shift-left.268, shift-right-logical.271)
  xor.273 = u32[10,1]{1,0} xor(add.266, or.272)
  get-tuple-element.227 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=5
  add.275 = u32[10,1]{1,0} add(xor.273, get-tuple-element.227)
  add.276 = s32[] add(get-tuple-element.223, constant.232)
  convert.277 = u32[] convert(add.276)
  broadcast.278 = u32[10,1]{1,0} broadcast(convert.277), dimensions={}
  add.279 = u32[10,1]{1,0} add(add.275, broadcast.278)
  get-tuple-element.228 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=6
  get-tuple-element.230 = u32[4]{0} get-tuple-element(arg_tuple.221), index=8
  ROOT tuple.281 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.280, add.233, add.274, add.279, get-tuple-element.227, get-tuple-element.228, get-tuple-element.226, get-tuple-element.230, get-tuple-element.229)
}

region_5.282 {
  arg_tuple.283 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.285 = s32[] get-tuple-element(arg_tuple.283), index=1
  get-tuple-element.286 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=2
  get-tuple-element.287 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=3
  get-tuple-element.288 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=4
  get-tuple-element.289 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=5
  get-tuple-element.290 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=6
  get-tuple-element.291 = u32[4]{0} get-tuple-element(arg_tuple.283), index=7
  get-tuple-element.292 = u32[4]{0} get-tuple-element(arg_tuple.283), index=8
  get-tuple-element.284 = s32[] get-tuple-element(arg_tuple.283), index=0
  constant.293 = s32[] constant(5)
  ROOT compare.294 = pred[] compare(get-tuple-element.284, constant.293), direction=LT
}

_randint.295 {
  constant.303 = s32[] constant(0)
  Arg_0.296 = u32[10,2]{1,0} parameter(0)
  call.312 = u32[10,2,2]{2,1,0} call(Arg_0.296), to_apply=_threefry_split.102
  slice.313 = u32[10,1,2]{2,1,0} slice(call.312), slice={[0:10], [0:1], [0:2]}
  reshape.314 = u32[10,2]{1,0} reshape(slice.313)
  slice.317 = u32[10,1]{1,0} slice(reshape.314), slice={[0:10], [0:1]}
  slice.318 = u32[10,1]{1,0} slice(reshape.314), slice={[0:10], [1:2]}
  xor.319 = u32[10,1]{1,0} xor(slice.317, slice.318)
  constant.299 = u32[] constant(466688986)
  broadcast.300 = u32[10,1]{1,0} broadcast(constant.299), dimensions={}
  xor.320 = u32[10,1]{1,0} xor(xor.319, broadcast.300)
  constant.305 = u32[4]{0} constant({13, 15, 26, 6})
  constant.304 = u32[4]{0} constant({17, 29, 16, 24})
  tuple.321 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.303, constant.303, slice.317, slice.318, slice.318, xor.320, slice.317, constant.305, constant.304)
  while.322 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.321), condition=region_3.207, body=region_2.145
  get-tuple-element.323 = s32[] get-tuple-element(while.322), index=0
  get-tuple-element.324 = s32[] get-tuple-element(while.322), index=1
  get-tuple-element.326 = u32[10,1]{1,0} get-tuple-element(while.322), index=3
  get-tuple-element.327 = u32[10,1]{1,0} get-tuple-element(while.322), index=4
  get-tuple-element.328 = u32[10,1]{1,0} get-tuple-element(while.322), index=5
  get-tuple-element.329 = u32[10,1]{1,0} get-tuple-element(while.322), index=6
  get-tuple-element.330 = u32[4]{0} get-tuple-element(while.322), index=7
  get-tuple-element.331 = u32[4]{0} get-tuple-element(while.322), index=8
  slice.315 = u32[10,1,2]{2,1,0} slice(call.312), slice={[0:10], [1:2], [0:2]}
  reshape.316 = u32[10,2]{1,0} reshape(slice.315)
  slice.333 = u32[10,1]{1,0} slice(reshape.316), slice={[0:10], [0:1]}
  slice.334 = u32[10,1]{1,0} slice(reshape.316), slice={[0:10], [1:2]}
  xor.335 = u32[10,1]{1,0} xor(slice.333, slice.334)
  xor.336 = u32[10,1]{1,0} xor(xor.335, broadcast.300)
  tuple.337 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.303, constant.303, slice.333, slice.334, slice.334, xor.336, slice.333, constant.305, constant.304)
  while.338 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.337), condition=region_5.282, body=region_4.220
  get-tuple-element.339 = s32[] get-tuple-element(while.338), index=0
  get-tuple-element.340 = s32[] get-tuple-element(while.338), index=1
  get-tuple-element.342 = u32[10,1]{1,0} get-tuple-element(while.338), index=3
  get-tuple-element.343 = u32[10,1]{1,0} get-tuple-element(while.338), index=4
  get-tuple-element.344 = u32[10,1]{1,0} get-tuple-element(while.338), index=5
  get-tuple-element.345 = u32[10,1]{1,0} get-tuple-element(while.338), index=6
  get-tuple-element.346 = u32[4]{0} get-tuple-element(while.338), index=7
  get-tuple-element.347 = u32[4]{0} get-tuple-element(while.338), index=8
  Arg_1.297 = s32[] parameter(1)
  constant.307 = s32[] constant(-2147483648)
  constant.306 = s32[] constant(2147483647)
  call.310 = s32[] call(Arg_1.297, constant.307, constant.306), to_apply=clip_0.9
  broadcast.370 = s32[10]{0} broadcast(call.310), dimensions={}
  get-tuple-element.325 = u32[10,1]{1,0} get-tuple-element(while.322), index=2
  reshape.332 = u32[10]{0} reshape(get-tuple-element.325)
  Arg_2.298 = s32[] parameter(2)
  call.308 = s32[] call(constant.306, constant.307, constant.306), to_apply=clip.3
  compare.309 = pred[] compare(Arg_2.298, call.308), direction=GT
  call.311 = s32[] call(Arg_2.298, constant.307, constant.306), to_apply=clip_0.15
  compare.353 = pred[] compare(call.311, call.310), direction=GT
  and.354 = pred[] and(compare.309, compare.353)
  compare.351 = pred[] compare(call.311, call.310), direction=LE
  constant.302 = u32[] constant(1)
  subtract.349 = s32[] subtract(call.311, call.310)
  convert.350 = u32[] convert(subtract.349)
  select.352 = u32[] select(compare.351, constant.302, convert.350)
  add.355 = u32[] add(select.352, constant.302)
  select.356 = u32[] select(and.354, add.355, select.352)
  broadcast.360 = u32[10]{0} broadcast(select.356), dimensions={}
  remainder.361 = u32[10]{0} remainder(reshape.332, broadcast.360)
  constant.301 = u32[] constant(65536)
  remainder.357 = u32[] remainder(constant.301, select.356)
  multiply.358 = u32[] multiply(remainder.357, remainder.357)
  remainder.359 = u32[] remainder(multiply.358, select.356)
  broadcast.362 = u32[10]{0} broadcast(remainder.359), dimensions={}
  multiply.363 = u32[10]{0} multiply(remainder.361, broadcast.362)
  get-tuple-element.341 = u32[10,1]{1,0} get-tuple-element(while.338), index=2
  reshape.348 = u32[10]{0} reshape(get-tuple-element.341)
  broadcast.364 = u32[10]{0} broadcast(select.356), dimensions={}
  remainder.365 = u32[10]{0} remainder(reshape.348, broadcast.364)
  add.366 = u32[10]{0} add(multiply.363, remainder.365)
  broadcast.367 = u32[10]{0} broadcast(select.356), dimensions={}
  remainder.368 = u32[10]{0} remainder(add.366, broadcast.367)
  convert.369 = s32[10]{0} convert(remainder.368)
  ROOT add.371 = s32[10]{0} add(broadcast.370, convert.369)
}

region_7.372 {
  arg_tuple.373 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.374 = s32[] get-tuple-element(arg_tuple.373), index=0
  constant.384 = s32[] constant(1)
  add.432 = s32[] add(get-tuple-element.374, constant.384)
  get-tuple-element.375 = s32[] get-tuple-element(arg_tuple.373), index=1
  add.385 = s32[] add(get-tuple-element.375, constant.384)
  get-tuple-element.376 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=2
  get-tuple-element.377 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=3
  add.388 = u32[10,1]{1,0} add(get-tuple-element.376, get-tuple-element.377)
  get-tuple-element.381 = u32[4]{0} get-tuple-element(arg_tuple.373), index=7
  slice.386 = u32[1]{0} slice(get-tuple-element.381), slice={[0:1]}
  reshape.387 = u32[] reshape(slice.386)
  broadcast.389 = u32[10,1]{1,0} broadcast(reshape.387), dimensions={}
  shift-left.390 = u32[10,1]{1,0} shift-left(get-tuple-element.377, broadcast.389)
  constant.383 = u32[] constant(32)
  subtract.391 = u32[] subtract(constant.383, reshape.387)
  broadcast.392 = u32[10,1]{1,0} broadcast(subtract.391), dimensions={}
  shift-right-logical.393 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.377, broadcast.392)
  or.394 = u32[10,1]{1,0} or(shift-left.390, shift-right-logical.393)
  xor.395 = u32[10,1]{1,0} xor(add.388, or.394)
  add.398 = u32[10,1]{1,0} add(add.388, xor.395)
  slice.396 = u32[1]{0} slice(get-tuple-element.381), slice={[1:2]}
  reshape.397 = u32[] reshape(slice.396)
  broadcast.399 = u32[10,1]{1,0} broadcast(reshape.397), dimensions={}
  shift-left.400 = u32[10,1]{1,0} shift-left(xor.395, broadcast.399)
  subtract.401 = u32[] subtract(constant.383, reshape.397)
  broadcast.402 = u32[10,1]{1,0} broadcast(subtract.401), dimensions={}
  shift-right-logical.403 = u32[10,1]{1,0} shift-right-logical(xor.395, broadcast.402)
  or.404 = u32[10,1]{1,0} or(shift-left.400, shift-right-logical.403)
  xor.405 = u32[10,1]{1,0} xor(add.398, or.404)
  add.408 = u32[10,1]{1,0} add(add.398, xor.405)
  slice.406 = u32[1]{0} slice(get-tuple-element.381), slice={[2:3]}
  reshape.407 = u32[] reshape(slice.406)
  broadcast.409 = u32[10,1]{1,0} broadcast(reshape.407), dimensions={}
  shift-left.410 = u32[10,1]{1,0} shift-left(xor.405, broadcast.409)
  subtract.411 = u32[] subtract(constant.383, reshape.407)
  broadcast.412 = u32[10,1]{1,0} broadcast(subtract.411), dimensions={}
  shift-right-logical.413 = u32[10,1]{1,0} shift-right-logical(xor.405, broadcast.412)
  or.414 = u32[10,1]{1,0} or(shift-left.410, shift-right-logical.413)
  xor.415 = u32[10,1]{1,0} xor(add.408, or.414)
  add.418 = u32[10,1]{1,0} add(add.408, xor.415)
  get-tuple-element.378 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=4
  add.426 = u32[10,1]{1,0} add(add.418, get-tuple-element.378)
  slice.416 = u32[1]{0} slice(get-tuple-element.381), slice={[3:4]}
  reshape.417 = u32[] reshape(slice.416)
  broadcast.419 = u32[10,1]{1,0} broadcast(reshape.417), dimensions={}
  shift-left.420 = u32[10,1]{1,0} shift-left(xor.415, broadcast.419)
  subtract.421 = u32[] subtract(constant.383, reshape.417)
  broadcast.422 = u32[10,1]{1,0} broadcast(subtract.421), dimensions={}
  shift-right-logical.423 = u32[10,1]{1,0} shift-right-logical(xor.415, broadcast.422)
  or.424 = u32[10,1]{1,0} or(shift-left.420, shift-right-logical.423)
  xor.425 = u32[10,1]{1,0} xor(add.418, or.424)
  get-tuple-element.379 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=5
  add.427 = u32[10,1]{1,0} add(xor.425, get-tuple-element.379)
  add.428 = s32[] add(get-tuple-element.375, constant.384)
  convert.429 = u32[] convert(add.428)
  broadcast.430 = u32[10,1]{1,0} broadcast(convert.429), dimensions={}
  add.431 = u32[10,1]{1,0} add(add.427, broadcast.430)
  get-tuple-element.380 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=6
  get-tuple-element.382 = u32[4]{0} get-tuple-element(arg_tuple.373), index=8
  ROOT tuple.433 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.432, add.385, add.426, add.431, get-tuple-element.379, get-tuple-element.380, get-tuple-element.378, get-tuple-element.382, get-tuple-element.381)
}

region_8.434 {
  arg_tuple.435 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
  get-tuple-element.437 = s32[] get-tuple-element(arg_tuple.435), index=1
  get-tuple-element.438 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=2
  get-tuple-element.439 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=3
  get-tuple-element.440 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=4
  get-tuple-element.441 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=5
  get-tuple-element.442 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=6
  get-tuple-element.443 = u32[4]{0} get-tuple-element(arg_tuple.435), index=7
  get-tuple-element.444 = u32[4]{0} get-tuple-element(arg_tuple.435), index=8
  get-tuple-element.436 = s32[] get-tuple-element(arg_tuple.435), index=0
  constant.445 = s32[] constant(5)
  ROOT compare.446 = pred[] compare(get-tuple-element.436, constant.445), direction=LT
}

_uniform.447 {
  constant.459 = s32[] constant(0)
  Arg_0.448 = u32[10,2]{1,0} parameter(0)
  slice.464 = u32[10,1]{1,0} slice(Arg_0.448), slice={[0:10], [0:1]}
  slice.465 = u32[10,1]{1,0} slice(Arg_0.448), slice={[0:10], [1:2]}
  xor.466 = u32[10,1]{1,0} xor(slice.464, slice.465)
  constant.457 = u32[] constant(466688986)
  broadcast.458 = u32[10,1]{1,0} broadcast(constant.457), dimensions={}
  xor.467 = u32[10,1]{1,0} xor(xor.466, broadcast.458)
  constant.461 = u32[4]{0} constant({13, 15, 26, 6})
  constant.460 = u32[4]{0} constant({17, 29, 16, 24})
  tuple.468 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.459, constant.459, slice.464, slice.465, slice.465, xor.467, slice.464, constant.461, constant.460)
  while.469 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.468), condition=region_8.434, body=region_7.372
  get-tuple-element.470 = s32[] get-tuple-element(while.469), index=0
  get-tuple-element.471 = s32[] get-tuple-element(while.469), index=1
  get-tuple-element.473 = u32[10,1]{1,0} get-tuple-element(while.469), index=3
  get-tuple-element.474 = u32[10,1]{1,0} get-tuple-element(while.469), index=4
  get-tuple-element.475 = u32[10,1]{1,0} get-tuple-element(while.469), index=5
  get-tuple-element.476 = u32[10,1]{1,0} get-tuple-element(while.469), index=6
  get-tuple-element.477 = u32[4]{0} get-tuple-element(while.469), index=7
  get-tuple-element.478 = u32[4]{0} get-tuple-element(while.469), index=8
  Arg_1.449 = f32[] parameter(1)
  reshape.494 = f32[1,1]{1,0} reshape(Arg_1.449)
  broadcast.495 = f32[1,1]{1,0} broadcast(reshape.494), dimensions={0,1}
  reshape.496 = f32[1]{0} reshape(broadcast.495)
  broadcast.497 = f32[10,1]{1,0} broadcast(reshape.496), dimensions={1}
  get-tuple-element.472 = u32[10,1]{1,0} get-tuple-element(while.469), index=2
  constant.455 = u32[] constant(9)
  broadcast.456 = u32[10,1]{1,0} broadcast(constant.455), dimensions={}
  shift-right-logical.479 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.472, broadcast.456)
  constant.453 = u32[] constant(1065353216)
  broadcast.454 = u32[10,1]{1,0} broadcast(constant.453), dimensions={}
  or.480 = u32[10,1]{1,0} or(shift-right-logical.479, broadcast.454)
  bitcast-convert.481 = f32[10,1]{1,0} bitcast-convert(or.480)
  constant.451 = f32[] constant(1)
  broadcast.452 = f32[10,1]{1,0} broadcast(constant.451), dimensions={}
  subtract.482 = f32[10,1]{1,0} subtract(bitcast-convert.481, broadcast.452)
  Arg_2.450 = f32[] parameter(2)
  reshape.463 = f32[1]{0} reshape(Arg_2.450)
  reshape.462 = f32[1]{0} reshape(Arg_1.449)
  subtract.483 = f32[1]{0} subtract(reshape.463, reshape.462)
  reshape.484 = f32[1,1]{1,0} reshape(subtract.483)
  broadcast.485 = f32[1,1]{1,0} broadcast(reshape.484), dimensions={0,1}
  reshape.486 = f32[1]{0} reshape(broadcast.485)
  broadcast.487 = f32[10,1]{1,0} broadcast(reshape.486), dimensions={1}
  multiply.488 = f32[10,1]{1,0} multiply(subtract.482, broadcast.487)
  reshape.489 = f32[1,1]{1,0} reshape(Arg_1.449)
  broadcast.490 = f32[1,1]{1,0} broadcast(reshape.489), dimensions={0,1}
  reshape.491 = f32[1]{0} reshape(broadcast.490)
  broadcast.492 = f32[10,1]{1,0} broadcast(reshape.491), dimensions={1}
  add.493 = f32[10,1]{1,0} add(multiply.488, broadcast.492)
  ROOT maximum.498 = f32[10,1]{1,0} maximum(broadcast.497, add.493)
}

region_9.499 {
  arg_tuple.500 = (f32[10]{0}, f32[10,1]{1,0}) parameter(0)
  get-tuple-element.502 = f32[10,1]{1,0} get-tuple-element(arg_tuple.500), index=1
  reshape.511 = f32[10]{0} reshape(get-tuple-element.502)
  constant.503 = f32[] constant(2)
  broadcast.504 = f32[10]{0} broadcast(constant.503), dimensions={}
  compare.512 = pred[10]{0} compare(reshape.511, broadcast.504), direction=LT
  get-tuple-element.501 = f32[10]{0} get-tuple-element(arg_tuple.500), index=0
  constant.507 = f32[] constant(0.1)
  broadcast.508 = f32[10]{0} broadcast(constant.507), dimensions={}
  add.509 = f32[10]{0} add(get-tuple-element.501, broadcast.508)
  select.513 = f32[10]{0} select(compare.512, add.509, get-tuple-element.501)
  reshape.514 = pred[10,1]{1,0} reshape(compare.512)
  constant.505 = f32[] constant(0.1)
  broadcast.506 = f32[10,1]{1,0} broadcast(constant.505), dimensions={}
  add.510 = f32[10,1]{1,0} add(get-tuple-element.502, broadcast.506)
  select.515 = f32[10,1]{1,0} select(reshape.514, add.510, get-tuple-element.502)
  ROOT tuple.516 = (f32[10]{0}, f32[10,1]{1,0}) tuple(select.513, select.515)
}

region_11.517 {
  Arg_0.518 = pred[] parameter(0)
  Arg_1.519 = pred[] parameter(1)
  ROOT or.520 = pred[] or(Arg_0.518, Arg_1.519)
}

region_10.521 {
  arg_tuple.522 = (f32[10]{0}, f32[10,1]{1,0}) parameter(0)
  get-tuple-element.523 = f32[10]{0} get-tuple-element(arg_tuple.522), index=0
  get-tuple-element.524 = f32[10,1]{1,0} get-tuple-element(arg_tuple.522), index=1
  reshape.528 = f32[10]{0} reshape(get-tuple-element.524)
  constant.526 = f32[] constant(2)
  broadcast.527 = f32[10]{0} broadcast(constant.526), dimensions={}
  compare.529 = pred[10]{0} compare(reshape.528, broadcast.527), direction=LT
  constant.525 = pred[] constant(false)
  ROOT reduce.530 = pred[] reduce(compare.529, constant.525), dimensions={0}, to_apply=region_11.517
}

region_6.531 {
  arg_tuple.532 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) parameter(0)
  get-tuple-element.535 = s32[10]{0} get-tuple-element(arg_tuple.532), index=2
  constant.536 = s32[] constant(5)
  broadcast.537 = s32[10]{0} broadcast(constant.536), dimensions={}
  compare.556 = pred[10]{0} compare(get-tuple-element.535, broadcast.537), direction=LT
  get-tuple-element.533 = f32[10]{0} get-tuple-element(arg_tuple.532), index=0
  constant.542 = s32[] constant(32)
  broadcast.543 = s32[10]{0} broadcast(constant.542), dimensions={}
  shift-right-logical.544 = s32[10]{0} shift-right-logical(get-tuple-element.535, broadcast.543)
  convert.545 = u32[10]{0} convert(shift-right-logical.544)
  reshape.546 = u32[10,1]{1,0} reshape(convert.545)
  convert.547 = u32[10]{0} convert(get-tuple-element.535)
  reshape.548 = u32[10,1]{1,0} reshape(convert.547)
  concatenate.549 = u32[10,2]{1,0} concatenate(reshape.546, reshape.548), dimensions={1}
  constant.541 = f32[] constant(0)
  constant.540 = f32[] constant(1)
  call.550 = f32[10,1]{1,0} call(concatenate.549, constant.541, constant.540), to_apply=_uniform.447
  tuple.551 = (f32[10]{0}, f32[10,1]{1,0}) tuple(get-tuple-element.533, call.550)
  while.552 = (f32[10]{0}, f32[10,1]{1,0}) while(tuple.551), condition=region_10.521, body=region_9.499
  get-tuple-element.553 = f32[10]{0} get-tuple-element(while.552), index=0
  select.557 = f32[10]{0} select(compare.556, get-tuple-element.553, get-tuple-element.533)
  reshape.558 = pred[10,1]{1,0} reshape(compare.556)
  get-tuple-element.554 = f32[10,1]{1,0} get-tuple-element(while.552), index=1
  get-tuple-element.534 = f32[10,1]{1,0} get-tuple-element(arg_tuple.532), index=1
  select.559 = f32[10,1]{1,0} select(reshape.558, get-tuple-element.554, get-tuple-element.534)
  constant.538 = s32[] constant(1)
  broadcast.539 = s32[10]{0} broadcast(constant.538), dimensions={}
  add.555 = s32[10]{0} add(get-tuple-element.535, broadcast.539)
  select.560 = s32[10]{0} select(compare.556, add.555, get-tuple-element.535)
  ROOT tuple.561 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) tuple(select.557, select.559, select.560)
}

region_13.562 {
  Arg_0.563 = pred[] parameter(0)
  Arg_1.564 = pred[] parameter(1)
  ROOT or.565 = pred[] or(Arg_0.563, Arg_1.564)
}

region_12.566 {
  arg_tuple.567 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) parameter(0)
  get-tuple-element.568 = f32[10]{0} get-tuple-element(arg_tuple.567), index=0
  get-tuple-element.569 = f32[10,1]{1,0} get-tuple-element(arg_tuple.567), index=1
  get-tuple-element.570 = s32[10]{0} get-tuple-element(arg_tuple.567), index=2
  constant.572 = s32[] constant(5)
  broadcast.573 = s32[10]{0} broadcast(constant.572), dimensions={}
  compare.574 = pred[10]{0} compare(get-tuple-element.570, broadcast.573), direction=LT
  constant.571 = pred[] constant(false)
  ROOT reduce.575 = pred[] reduce(compare.574, constant.571), dimensions={0}, to_apply=region_13.562
}

solve.576 {
  constant.579 = f32[] constant(0)
  broadcast.580 = f32[10]{0} broadcast(constant.579), dimensions={}
  Arg_0.577 = f32[10,1]{1,0} parameter(0)
  Arg_1.578 = u32[10,2]{1,0} parameter(1)
  constant.581 = s32[] constant(-2)
  constant.582 = s32[] constant(2)
  call.583 = s32[10]{0} call(Arg_1.578, constant.581, constant.582), to_apply=_randint.295
  tuple.584 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) tuple(broadcast.580, Arg_0.577, call.583)
  while.585 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) while(tuple.584), condition=region_12.566, body=region_6.531
  get-tuple-element.586 = f32[10]{0} get-tuple-element(while.585), index=0
  ROOT get-tuple-element.587 = f32[10,1]{1,0} get-tuple-element(while.585), index=1
  get-tuple-element.588 = s32[10]{0} get-tuple-element(while.585), index=2
}

ENTRY main.591 {
  Arg_0.1 = f32[10,1]{1,0} parameter(0)
  Arg_1.2 = u32[10,2]{1,0} parameter(1)
  call.589 = f32[10,1]{1,0} call(Arg_0.1, Arg_1.2), to_apply=solve.576
  ROOT tuple.590 = (f32[10,1]{1,0}) tuple(call.589)
}
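One way to read the dump above, sketched in plain Python rather than JAX: the inner loop condition (`region_10.521`) reduces the per-lane predicate with `or`, and the inner loop body (`region_9.499`) uses `select` to leave finished lanes unchanged. Under that reading, the whole batch keeps iterating until the slowest lane is done. The helper names `batched_while` and `single_lane_steps` are hypothetical and only illustrate this reading. The constants `2.0` and `0.1` are taken from the dump.

```python
def batched_while(ys, limit=2.0, step=0.1):
    """Pure-Python emulation of the batched inner loop (illustrative only)."""
    ys = list(ys)
    iterations = 0
    # Condition, as in region_10.521: OR-reduce the per-lane test y < limit.
    while any(y < limit for y in ys):
        # Body, as in region_9.499: every lane evaluates y + step, then a
        # select keeps the old value for lanes whose test is already false.
        ys = [y + step if y < limit else y for y in ys]
        iterations += 1
    return ys, iterations


def single_lane_steps(y, limit=2.0, step=0.1):
    """Trip count one lane would need if it ran on its own."""
    n = 0
    while y < limit:
        y += step
        n += 1
    return n


starts = [0.0, 1.0, 1.9, 5.0]
final_ys, lockstep_iters = batched_while(starts)
per_lane = [single_lane_steps(y) for y in starts]
print("per-lane trip counts:", per_lane)
print("lock-step trip count:", lockstep_iters)
```

In this emulation the lock-step trip count equals the largest per-lane trip count, so lanes that finish early still pay for the slowest one. Whether this accounts for the gap to `pmap` reported here is not established by the sketch.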
patrick-kidger commented 5 months ago

For performance-related things like this, the cause is usually in XLA. JAX is mostly at the mercy of whatever code XLA generates.

Unfortunately, the parallelism part of this isn't something I'm familiar with at all. I think @sharadmv might know more? This one is out of my wheelhouse, I'm afraid.
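To see what XLA is handed and what it produces, the lowered and compiled program text can be dumped for each variant and compared side by side. A minimal sketch, assuming a toy loop `count_up` (hypothetical, standing in for the solver in this issue) and using `jax.jit(...).lower(...)`:

```python
import jax
import jax.numpy as jnp


def count_up(y):
    # Toy data-dependent loop; not the solver from this issue.
    return jax.lax.while_loop(lambda v: v < 2.0, lambda v: v + 0.1, y)


ys = jnp.linspace(0.0, 1.5, 8)
batched = jax.jit(jax.vmap(count_up))

# Program text as emitted by JAX, before XLA optimizations.
lowered = batched.lower(ys)
lowered_text = lowered.as_text()

# Program text after XLA has compiled it for the current backend.
compiled_text = lowered.compile().as_text()

print(lowered_text[:500])
```

Running the same steps on the `pmap` and sharded versions and diffing the output may help narrow down where the generated programs diverge.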