Open lockwo opened 6 months ago
To be extra clear: if you set count_initial = 0, sharding and pmap run at the same speed. The slowdown only appears when count_initial varies per shard.
Also, if you replace the outer loop body with

    new_t, new_y, theta = inner_loop_body((t, y, theta))
    # new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
    return (new_t, new_y, theta, count + 1)

you again see the same speed, which indicates that the inner while loop is needed to trigger the slowdown.
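To make concrete what a per-element trip count means for a batched loop, here is a minimal sketch. It relies on one fact about JAX: when `jax.vmap` is applied to a `jax.lax.while_loop` whose predicate differs across the batch, the loop keeps iterating until every element is finished and masks the updates of elements that are already done. The names `count_to_five`, `lockstep_model`, and the starting values are hypothetical and only serve the illustration. Whether this lockstep behaviour accounts for the slowdown reported here is not established; it is only one thing worth checking.

```python
import jax
import jax.numpy as jnp


def count_to_five(count):
    # Per-element loop: the trip count depends on the starting value.
    def cond(state):
        c, steps = state
        return c < 5

    def body(state):
        c, steps = state
        return c + 1, steps + 1

    return jax.lax.while_loop(cond, body, (count, jnp.int32(0)))


def lockstep_model(counts):
    # Hand-written model of a batched loop: iterate while ANY element is
    # unfinished, and mask out updates for elements that are already done.
    def cond(state):
        c, steps, total = state
        return jnp.any(c < 5)

    def body(state):
        c, steps, total = state
        active = c < 5
        c = jnp.where(active, c + 1, c)
        steps = jnp.where(active, steps + 1, steps)
        return c, steps, total + 1

    init = (counts, jnp.zeros_like(counts), jnp.int32(0))
    return jax.lax.while_loop(cond, body, init)


# Hypothetical per-element starting counts (stand-ins for count_initial).
counts = jnp.array([0, 3, -2], dtype=jnp.int32)

_, per_element_steps = jax.vmap(count_to_five)(counts)
_, model_steps, lockstep_iterations = lockstep_model(counts)

print(per_element_steps)     # useful work per element
print(lockstep_iterations)   # iterations the whole batch runs in the model
```

In the model, the batch runs for as many iterations as its slowest element needs, so elements that finish early spend the remaining iterations on masked no-ops. With identical starting counts that overhead disappears, which matches the count_initial = 0 observation.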
@patrick-kidger you mentioned in https://github.com/patrick-kidger/diffrax/issues/407 that you suspect this is within XLA. Do you have any advice on how to approach that? I haven't investigated an XLA problem this complex before. Even my reduced example (shown below) produces XLA HLO that is not very readable (shown further below). Is there a go-to issue or piece of XLA/JAX documentation on how to tell whether a bug is in JAX or in XLA, and how to spot it?
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils


def solve(init, key):
    # Inner loop: step y upward by 0.1 until it reaches 2.
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 2

    def inner_loop_body(state):
        t, y, theta = state
        return (t + 0.1, y + 0.1, theta)

    # Outer loop: run until count reaches 5.
    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        # Resample y each outer step, so the inner trip count varies too.
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    # Random start in [-2, 2), so the outer trip count differs per key.
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]
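One possible way to narrow down which layer is responsible is to inspect the program at each stage: the jaxpr that JAX traced, the StableHLO that JAX hands to XLA, and the HLO that XLA produces after its optimization passes. The sketch below does this for a small stand-in function; `toy` and `xs` are hypothetical names, and the same calls should apply to a batched `solve`. It is a starting point under those assumptions, not official debugging guidance.

```python
import jax
import jax.numpy as jnp


def toy(x):
    # Hypothetical stand-in for `solve`: the trip count depends on the input.
    return jax.lax.while_loop(lambda v: v < 2.0, lambda v: v + 0.1, x)


batched = jax.vmap(toy)
xs = jnp.linspace(0.0, 1.0, 4)

# Stage 1: what JAX traced, before any XLA involvement.
jaxpr_text = str(jax.make_jaxpr(batched)(xs))

# Stage 2: the StableHLO module that JAX passes to XLA.
lowered = jax.jit(batched).lower(xs)
stablehlo_text = lowered.as_text()

# Stage 3: the HLO after XLA's optimization passes for the current backend.
optimized_hlo_text = lowered.compile().as_text()

print(jaxpr_text)
```

If the fast and slow variants look equivalent in the first two stages and only diverge in the optimized HLO (for example, extra collectives or copies inside a while body), that would point towards XLA rather than JAX. Setting the environment variable `XLA_FLAGS=--xla_dump_to=<dir>` before starting Python makes XLA write its HLO files to that directory, which can be easier to diff than printed strings.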
HloModule xla_computation_solve, entry_computation_layout={(f32[10,1]{1,0}, u32[10,2]{1,0})->(f32[10,1]{1,0})}
clip.3 {
Arg_2.6 = s32[] parameter(2)
Arg_1.5 = s32[] parameter(1)
Arg_0.4 = s32[] parameter(0)
maximum.7 = s32[] maximum(Arg_1.5, Arg_0.4)
ROOT minimum.8 = s32[] minimum(Arg_2.6, maximum.7)
}
clip_0.9 {
Arg_2.12 = s32[] parameter(2)
Arg_1.11 = s32[] parameter(1)
Arg_0.10 = s32[] parameter(0)
maximum.13 = s32[] maximum(Arg_1.11, Arg_0.10)
ROOT minimum.14 = s32[] minimum(Arg_2.12, maximum.13)
}
clip_0.15 {
Arg_2.18 = s32[] parameter(2)
Arg_1.17 = s32[] parameter(1)
Arg_0.16 = s32[] parameter(0)
maximum.19 = s32[] maximum(Arg_1.17, Arg_0.16)
ROOT minimum.20 = s32[] minimum(Arg_2.18, maximum.19)
}
region_0.21 {
arg_tuple.22 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.23 = s32[] get-tuple-element(arg_tuple.22), index=0
constant.33 = s32[] constant(1)
add.87 = s32[] add(get-tuple-element.23, constant.33)
get-tuple-element.24 = s32[] get-tuple-element(arg_tuple.22), index=1
add.34 = s32[] add(get-tuple-element.24, constant.33)
get-tuple-element.25 = u32[10,2]{1,0} get-tuple-element(arg_tuple.22), index=2
get-tuple-element.26 = u32[10,2]{1,0} get-tuple-element(arg_tuple.22), index=3
add.37 = u32[10,2]{1,0} add(get-tuple-element.25, get-tuple-element.26)
get-tuple-element.30 = u32[4]{0} get-tuple-element(arg_tuple.22), index=7
slice.35 = u32[1]{0} slice(get-tuple-element.30), slice={[0:1]}
reshape.36 = u32[] reshape(slice.35)
broadcast.38 = u32[10,2]{1,0} broadcast(reshape.36), dimensions={}
shift-left.39 = u32[10,2]{1,0} shift-left(get-tuple-element.26, broadcast.38)
constant.32 = u32[] constant(32)
subtract.40 = u32[] subtract(constant.32, reshape.36)
broadcast.41 = u32[10,2]{1,0} broadcast(subtract.40), dimensions={}
shift-right-logical.42 = u32[10,2]{1,0} shift-right-logical(get-tuple-element.26, broadcast.41)
or.43 = u32[10,2]{1,0} or(shift-left.39, shift-right-logical.42)
xor.44 = u32[10,2]{1,0} xor(add.37, or.43)
add.47 = u32[10,2]{1,0} add(add.37, xor.44)
slice.45 = u32[1]{0} slice(get-tuple-element.30), slice={[1:2]}
reshape.46 = u32[] reshape(slice.45)
broadcast.48 = u32[10,2]{1,0} broadcast(reshape.46), dimensions={}
shift-left.49 = u32[10,2]{1,0} shift-left(xor.44, broadcast.48)
subtract.50 = u32[] subtract(constant.32, reshape.46)
broadcast.51 = u32[10,2]{1,0} broadcast(subtract.50), dimensions={}
shift-right-logical.52 = u32[10,2]{1,0} shift-right-logical(xor.44, broadcast.51)
or.53 = u32[10,2]{1,0} or(shift-left.49, shift-right-logical.52)
xor.54 = u32[10,2]{1,0} xor(add.47, or.53)
add.57 = u32[10,2]{1,0} add(add.47, xor.54)
slice.55 = u32[1]{0} slice(get-tuple-element.30), slice={[2:3]}
reshape.56 = u32[] reshape(slice.55)
broadcast.58 = u32[10,2]{1,0} broadcast(reshape.56), dimensions={}
shift-left.59 = u32[10,2]{1,0} shift-left(xor.54, broadcast.58)
subtract.60 = u32[] subtract(constant.32, reshape.56)
broadcast.61 = u32[10,2]{1,0} broadcast(subtract.60), dimensions={}
shift-right-logical.62 = u32[10,2]{1,0} shift-right-logical(xor.54, broadcast.61)
or.63 = u32[10,2]{1,0} or(shift-left.59, shift-right-logical.62)
xor.64 = u32[10,2]{1,0} xor(add.57, or.63)
add.67 = u32[10,2]{1,0} add(add.57, xor.64)
get-tuple-element.27 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=4
broadcast.75 = u32[10,1]{1,0} broadcast(get-tuple-element.27), dimensions={0,1}
reshape.76 = u32[10]{0} reshape(broadcast.75)
broadcast.77 = u32[10,2]{1,0} broadcast(reshape.76), dimensions={0}
add.78 = u32[10,2]{1,0} add(add.67, broadcast.77)
slice.65 = u32[1]{0} slice(get-tuple-element.30), slice={[3:4]}
reshape.66 = u32[] reshape(slice.65)
broadcast.68 = u32[10,2]{1,0} broadcast(reshape.66), dimensions={}
shift-left.69 = u32[10,2]{1,0} shift-left(xor.64, broadcast.68)
subtract.70 = u32[] subtract(constant.32, reshape.66)
broadcast.71 = u32[10,2]{1,0} broadcast(subtract.70), dimensions={}
shift-right-logical.72 = u32[10,2]{1,0} shift-right-logical(xor.64, broadcast.71)
or.73 = u32[10,2]{1,0} or(shift-left.69, shift-right-logical.72)
xor.74 = u32[10,2]{1,0} xor(add.67, or.73)
get-tuple-element.28 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=5
broadcast.79 = u32[10,1]{1,0} broadcast(get-tuple-element.28), dimensions={0,1}
reshape.80 = u32[10]{0} reshape(broadcast.79)
broadcast.81 = u32[10,2]{1,0} broadcast(reshape.80), dimensions={0}
add.82 = u32[10,2]{1,0} add(xor.74, broadcast.81)
add.83 = s32[] add(get-tuple-element.24, constant.33)
convert.84 = u32[] convert(add.83)
broadcast.85 = u32[10,2]{1,0} broadcast(convert.84), dimensions={}
add.86 = u32[10,2]{1,0} add(add.82, broadcast.85)
get-tuple-element.29 = u32[10,1]{1,0} get-tuple-element(arg_tuple.22), index=6
get-tuple-element.31 = u32[4]{0} get-tuple-element(arg_tuple.22), index=8
ROOT tuple.88 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.87, add.34, add.78, add.86, get-tuple-element.28, get-tuple-element.29, get-tuple-element.27, get-tuple-element.31, get-tuple-element.30)
}
region_1.89 {
arg_tuple.90 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.92 = s32[] get-tuple-element(arg_tuple.90), index=1
get-tuple-element.93 = u32[10,2]{1,0} get-tuple-element(arg_tuple.90), index=2
get-tuple-element.94 = u32[10,2]{1,0} get-tuple-element(arg_tuple.90), index=3
get-tuple-element.95 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=4
get-tuple-element.96 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=5
get-tuple-element.97 = u32[10,1]{1,0} get-tuple-element(arg_tuple.90), index=6
get-tuple-element.98 = u32[4]{0} get-tuple-element(arg_tuple.90), index=7
get-tuple-element.99 = u32[4]{0} get-tuple-element(arg_tuple.90), index=8
get-tuple-element.91 = s32[] get-tuple-element(arg_tuple.90), index=0
constant.100 = s32[] constant(5)
ROOT compare.101 = pred[] compare(get-tuple-element.91, constant.100), direction=LT
}
_threefry_split.102 {
constant.106 = s32[] constant(0)
iota.109 = u32[4]{0} iota(), iota_dimension=0
slice.112 = u32[2]{0} slice(iota.109), slice={[0:2]}
reshape.114 = u32[1,2]{1,0} reshape(slice.112)
broadcast.118 = u32[1,2]{1,0} broadcast(reshape.114), dimensions={0,1}
reshape.119 = u32[2]{0} reshape(broadcast.118)
broadcast.120 = u32[10,2]{1,0} broadcast(reshape.119), dimensions={1}
Arg_0.103 = u32[10,2]{1,0} parameter(0)
slice.110 = u32[10,1]{1,0} slice(Arg_0.103), slice={[0:10], [0:1]}
broadcast.121 = u32[10,1]{1,0} broadcast(slice.110), dimensions={0,1}
reshape.122 = u32[10]{0} reshape(broadcast.121)
broadcast.123 = u32[10,2]{1,0} broadcast(reshape.122), dimensions={0}
add.124 = u32[10,2]{1,0} add(broadcast.120, broadcast.123)
slice.113 = u32[2]{0} slice(iota.109), slice={[2:4]}
reshape.115 = u32[1,2]{1,0} reshape(slice.113)
broadcast.125 = u32[1,2]{1,0} broadcast(reshape.115), dimensions={0,1}
reshape.126 = u32[2]{0} reshape(broadcast.125)
broadcast.127 = u32[10,2]{1,0} broadcast(reshape.126), dimensions={1}
slice.111 = u32[10,1]{1,0} slice(Arg_0.103), slice={[0:10], [1:2]}
broadcast.128 = u32[10,1]{1,0} broadcast(slice.111), dimensions={0,1}
reshape.129 = u32[10]{0} reshape(broadcast.128)
broadcast.130 = u32[10,2]{1,0} broadcast(reshape.129), dimensions={0}
add.131 = u32[10,2]{1,0} add(broadcast.127, broadcast.130)
xor.116 = u32[10,1]{1,0} xor(slice.110, slice.111)
constant.104 = u32[] constant(466688986)
broadcast.105 = u32[10,1]{1,0} broadcast(constant.104), dimensions={}
xor.117 = u32[10,1]{1,0} xor(xor.116, broadcast.105)
constant.108 = u32[4]{0} constant({13, 15, 26, 6})
constant.107 = u32[4]{0} constant({17, 29, 16, 24})
tuple.132 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.106, constant.106, add.124, add.131, slice.111, xor.117, slice.110, constant.108, constant.107)
while.133 = (s32[], s32[], u32[10,2]{1,0}, u32[10,2]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.132), condition=region_1.89, body=region_0.21
get-tuple-element.134 = s32[] get-tuple-element(while.133), index=0
get-tuple-element.135 = s32[] get-tuple-element(while.133), index=1
get-tuple-element.138 = u32[10,1]{1,0} get-tuple-element(while.133), index=4
get-tuple-element.139 = u32[10,1]{1,0} get-tuple-element(while.133), index=5
get-tuple-element.140 = u32[10,1]{1,0} get-tuple-element(while.133), index=6
get-tuple-element.141 = u32[4]{0} get-tuple-element(while.133), index=7
get-tuple-element.142 = u32[4]{0} get-tuple-element(while.133), index=8
get-tuple-element.136 = u32[10,2]{1,0} get-tuple-element(while.133), index=2
get-tuple-element.137 = u32[10,2]{1,0} get-tuple-element(while.133), index=3
concatenate.143 = u32[10,4]{1,0} concatenate(get-tuple-element.136, get-tuple-element.137), dimensions={1}
ROOT reshape.144 = u32[10,2,2]{2,1,0} reshape(concatenate.143)
}
region_2.145 {
arg_tuple.146 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.147 = s32[] get-tuple-element(arg_tuple.146), index=0
constant.157 = s32[] constant(1)
add.205 = s32[] add(get-tuple-element.147, constant.157)
get-tuple-element.148 = s32[] get-tuple-element(arg_tuple.146), index=1
add.158 = s32[] add(get-tuple-element.148, constant.157)
get-tuple-element.149 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=2
get-tuple-element.150 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=3
add.161 = u32[10,1]{1,0} add(get-tuple-element.149, get-tuple-element.150)
get-tuple-element.154 = u32[4]{0} get-tuple-element(arg_tuple.146), index=7
slice.159 = u32[1]{0} slice(get-tuple-element.154), slice={[0:1]}
reshape.160 = u32[] reshape(slice.159)
broadcast.162 = u32[10,1]{1,0} broadcast(reshape.160), dimensions={}
shift-left.163 = u32[10,1]{1,0} shift-left(get-tuple-element.150, broadcast.162)
constant.156 = u32[] constant(32)
subtract.164 = u32[] subtract(constant.156, reshape.160)
broadcast.165 = u32[10,1]{1,0} broadcast(subtract.164), dimensions={}
shift-right-logical.166 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.150, broadcast.165)
or.167 = u32[10,1]{1,0} or(shift-left.163, shift-right-logical.166)
xor.168 = u32[10,1]{1,0} xor(add.161, or.167)
add.171 = u32[10,1]{1,0} add(add.161, xor.168)
slice.169 = u32[1]{0} slice(get-tuple-element.154), slice={[1:2]}
reshape.170 = u32[] reshape(slice.169)
broadcast.172 = u32[10,1]{1,0} broadcast(reshape.170), dimensions={}
shift-left.173 = u32[10,1]{1,0} shift-left(xor.168, broadcast.172)
subtract.174 = u32[] subtract(constant.156, reshape.170)
broadcast.175 = u32[10,1]{1,0} broadcast(subtract.174), dimensions={}
shift-right-logical.176 = u32[10,1]{1,0} shift-right-logical(xor.168, broadcast.175)
or.177 = u32[10,1]{1,0} or(shift-left.173, shift-right-logical.176)
xor.178 = u32[10,1]{1,0} xor(add.171, or.177)
add.181 = u32[10,1]{1,0} add(add.171, xor.178)
slice.179 = u32[1]{0} slice(get-tuple-element.154), slice={[2:3]}
reshape.180 = u32[] reshape(slice.179)
broadcast.182 = u32[10,1]{1,0} broadcast(reshape.180), dimensions={}
shift-left.183 = u32[10,1]{1,0} shift-left(xor.178, broadcast.182)
subtract.184 = u32[] subtract(constant.156, reshape.180)
broadcast.185 = u32[10,1]{1,0} broadcast(subtract.184), dimensions={}
shift-right-logical.186 = u32[10,1]{1,0} shift-right-logical(xor.178, broadcast.185)
or.187 = u32[10,1]{1,0} or(shift-left.183, shift-right-logical.186)
xor.188 = u32[10,1]{1,0} xor(add.181, or.187)
add.191 = u32[10,1]{1,0} add(add.181, xor.188)
get-tuple-element.151 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=4
add.199 = u32[10,1]{1,0} add(add.191, get-tuple-element.151)
slice.189 = u32[1]{0} slice(get-tuple-element.154), slice={[3:4]}
reshape.190 = u32[] reshape(slice.189)
broadcast.192 = u32[10,1]{1,0} broadcast(reshape.190), dimensions={}
shift-left.193 = u32[10,1]{1,0} shift-left(xor.188, broadcast.192)
subtract.194 = u32[] subtract(constant.156, reshape.190)
broadcast.195 = u32[10,1]{1,0} broadcast(subtract.194), dimensions={}
shift-right-logical.196 = u32[10,1]{1,0} shift-right-logical(xor.188, broadcast.195)
or.197 = u32[10,1]{1,0} or(shift-left.193, shift-right-logical.196)
xor.198 = u32[10,1]{1,0} xor(add.191, or.197)
get-tuple-element.152 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=5
add.200 = u32[10,1]{1,0} add(xor.198, get-tuple-element.152)
add.201 = s32[] add(get-tuple-element.148, constant.157)
convert.202 = u32[] convert(add.201)
broadcast.203 = u32[10,1]{1,0} broadcast(convert.202), dimensions={}
add.204 = u32[10,1]{1,0} add(add.200, broadcast.203)
get-tuple-element.153 = u32[10,1]{1,0} get-tuple-element(arg_tuple.146), index=6
get-tuple-element.155 = u32[4]{0} get-tuple-element(arg_tuple.146), index=8
ROOT tuple.206 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.205, add.158, add.199, add.204, get-tuple-element.152, get-tuple-element.153, get-tuple-element.151, get-tuple-element.155, get-tuple-element.154)
}
region_3.207 {
arg_tuple.208 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.210 = s32[] get-tuple-element(arg_tuple.208), index=1
get-tuple-element.211 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=2
get-tuple-element.212 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=3
get-tuple-element.213 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=4
get-tuple-element.214 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=5
get-tuple-element.215 = u32[10,1]{1,0} get-tuple-element(arg_tuple.208), index=6
get-tuple-element.216 = u32[4]{0} get-tuple-element(arg_tuple.208), index=7
get-tuple-element.217 = u32[4]{0} get-tuple-element(arg_tuple.208), index=8
get-tuple-element.209 = s32[] get-tuple-element(arg_tuple.208), index=0
constant.218 = s32[] constant(5)
ROOT compare.219 = pred[] compare(get-tuple-element.209, constant.218), direction=LT
}
region_4.220 {
arg_tuple.221 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.222 = s32[] get-tuple-element(arg_tuple.221), index=0
constant.232 = s32[] constant(1)
add.280 = s32[] add(get-tuple-element.222, constant.232)
get-tuple-element.223 = s32[] get-tuple-element(arg_tuple.221), index=1
add.233 = s32[] add(get-tuple-element.223, constant.232)
get-tuple-element.224 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=2
get-tuple-element.225 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=3
add.236 = u32[10,1]{1,0} add(get-tuple-element.224, get-tuple-element.225)
get-tuple-element.229 = u32[4]{0} get-tuple-element(arg_tuple.221), index=7
slice.234 = u32[1]{0} slice(get-tuple-element.229), slice={[0:1]}
reshape.235 = u32[] reshape(slice.234)
broadcast.237 = u32[10,1]{1,0} broadcast(reshape.235), dimensions={}
shift-left.238 = u32[10,1]{1,0} shift-left(get-tuple-element.225, broadcast.237)
constant.231 = u32[] constant(32)
subtract.239 = u32[] subtract(constant.231, reshape.235)
broadcast.240 = u32[10,1]{1,0} broadcast(subtract.239), dimensions={}
shift-right-logical.241 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.225, broadcast.240)
or.242 = u32[10,1]{1,0} or(shift-left.238, shift-right-logical.241)
xor.243 = u32[10,1]{1,0} xor(add.236, or.242)
add.246 = u32[10,1]{1,0} add(add.236, xor.243)
slice.244 = u32[1]{0} slice(get-tuple-element.229), slice={[1:2]}
reshape.245 = u32[] reshape(slice.244)
broadcast.247 = u32[10,1]{1,0} broadcast(reshape.245), dimensions={}
shift-left.248 = u32[10,1]{1,0} shift-left(xor.243, broadcast.247)
subtract.249 = u32[] subtract(constant.231, reshape.245)
broadcast.250 = u32[10,1]{1,0} broadcast(subtract.249), dimensions={}
shift-right-logical.251 = u32[10,1]{1,0} shift-right-logical(xor.243, broadcast.250)
or.252 = u32[10,1]{1,0} or(shift-left.248, shift-right-logical.251)
xor.253 = u32[10,1]{1,0} xor(add.246, or.252)
add.256 = u32[10,1]{1,0} add(add.246, xor.253)
slice.254 = u32[1]{0} slice(get-tuple-element.229), slice={[2:3]}
reshape.255 = u32[] reshape(slice.254)
broadcast.257 = u32[10,1]{1,0} broadcast(reshape.255), dimensions={}
shift-left.258 = u32[10,1]{1,0} shift-left(xor.253, broadcast.257)
subtract.259 = u32[] subtract(constant.231, reshape.255)
broadcast.260 = u32[10,1]{1,0} broadcast(subtract.259), dimensions={}
shift-right-logical.261 = u32[10,1]{1,0} shift-right-logical(xor.253, broadcast.260)
or.262 = u32[10,1]{1,0} or(shift-left.258, shift-right-logical.261)
xor.263 = u32[10,1]{1,0} xor(add.256, or.262)
add.266 = u32[10,1]{1,0} add(add.256, xor.263)
get-tuple-element.226 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=4
add.274 = u32[10,1]{1,0} add(add.266, get-tuple-element.226)
slice.264 = u32[1]{0} slice(get-tuple-element.229), slice={[3:4]}
reshape.265 = u32[] reshape(slice.264)
broadcast.267 = u32[10,1]{1,0} broadcast(reshape.265), dimensions={}
shift-left.268 = u32[10,1]{1,0} shift-left(xor.263, broadcast.267)
subtract.269 = u32[] subtract(constant.231, reshape.265)
broadcast.270 = u32[10,1]{1,0} broadcast(subtract.269), dimensions={}
shift-right-logical.271 = u32[10,1]{1,0} shift-right-logical(xor.263, broadcast.270)
or.272 = u32[10,1]{1,0} or(shift-left.268, shift-right-logical.271)
xor.273 = u32[10,1]{1,0} xor(add.266, or.272)
get-tuple-element.227 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=5
add.275 = u32[10,1]{1,0} add(xor.273, get-tuple-element.227)
add.276 = s32[] add(get-tuple-element.223, constant.232)
convert.277 = u32[] convert(add.276)
broadcast.278 = u32[10,1]{1,0} broadcast(convert.277), dimensions={}
add.279 = u32[10,1]{1,0} add(add.275, broadcast.278)
get-tuple-element.228 = u32[10,1]{1,0} get-tuple-element(arg_tuple.221), index=6
get-tuple-element.230 = u32[4]{0} get-tuple-element(arg_tuple.221), index=8
ROOT tuple.281 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.280, add.233, add.274, add.279, get-tuple-element.227, get-tuple-element.228, get-tuple-element.226, get-tuple-element.230, get-tuple-element.229)
}
region_5.282 {
arg_tuple.283 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.285 = s32[] get-tuple-element(arg_tuple.283), index=1
get-tuple-element.286 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=2
get-tuple-element.287 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=3
get-tuple-element.288 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=4
get-tuple-element.289 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=5
get-tuple-element.290 = u32[10,1]{1,0} get-tuple-element(arg_tuple.283), index=6
get-tuple-element.291 = u32[4]{0} get-tuple-element(arg_tuple.283), index=7
get-tuple-element.292 = u32[4]{0} get-tuple-element(arg_tuple.283), index=8
get-tuple-element.284 = s32[] get-tuple-element(arg_tuple.283), index=0
constant.293 = s32[] constant(5)
ROOT compare.294 = pred[] compare(get-tuple-element.284, constant.293), direction=LT
}
_randint.295 {
constant.303 = s32[] constant(0)
Arg_0.296 = u32[10,2]{1,0} parameter(0)
call.312 = u32[10,2,2]{2,1,0} call(Arg_0.296), to_apply=_threefry_split.102
slice.313 = u32[10,1,2]{2,1,0} slice(call.312), slice={[0:10], [0:1], [0:2]}
reshape.314 = u32[10,2]{1,0} reshape(slice.313)
slice.317 = u32[10,1]{1,0} slice(reshape.314), slice={[0:10], [0:1]}
slice.318 = u32[10,1]{1,0} slice(reshape.314), slice={[0:10], [1:2]}
xor.319 = u32[10,1]{1,0} xor(slice.317, slice.318)
constant.299 = u32[] constant(466688986)
broadcast.300 = u32[10,1]{1,0} broadcast(constant.299), dimensions={}
xor.320 = u32[10,1]{1,0} xor(xor.319, broadcast.300)
constant.305 = u32[4]{0} constant({13, 15, 26, 6})
constant.304 = u32[4]{0} constant({17, 29, 16, 24})
tuple.321 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.303, constant.303, slice.317, slice.318, slice.318, xor.320, slice.317, constant.305, constant.304)
while.322 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.321), condition=region_3.207, body=region_2.145
get-tuple-element.323 = s32[] get-tuple-element(while.322), index=0
get-tuple-element.324 = s32[] get-tuple-element(while.322), index=1
get-tuple-element.326 = u32[10,1]{1,0} get-tuple-element(while.322), index=3
get-tuple-element.327 = u32[10,1]{1,0} get-tuple-element(while.322), index=4
get-tuple-element.328 = u32[10,1]{1,0} get-tuple-element(while.322), index=5
get-tuple-element.329 = u32[10,1]{1,0} get-tuple-element(while.322), index=6
get-tuple-element.330 = u32[4]{0} get-tuple-element(while.322), index=7
get-tuple-element.331 = u32[4]{0} get-tuple-element(while.322), index=8
slice.315 = u32[10,1,2]{2,1,0} slice(call.312), slice={[0:10], [1:2], [0:2]}
reshape.316 = u32[10,2]{1,0} reshape(slice.315)
slice.333 = u32[10,1]{1,0} slice(reshape.316), slice={[0:10], [0:1]}
slice.334 = u32[10,1]{1,0} slice(reshape.316), slice={[0:10], [1:2]}
xor.335 = u32[10,1]{1,0} xor(slice.333, slice.334)
xor.336 = u32[10,1]{1,0} xor(xor.335, broadcast.300)
tuple.337 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.303, constant.303, slice.333, slice.334, slice.334, xor.336, slice.333, constant.305, constant.304)
while.338 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.337), condition=region_5.282, body=region_4.220
get-tuple-element.339 = s32[] get-tuple-element(while.338), index=0
get-tuple-element.340 = s32[] get-tuple-element(while.338), index=1
get-tuple-element.342 = u32[10,1]{1,0} get-tuple-element(while.338), index=3
get-tuple-element.343 = u32[10,1]{1,0} get-tuple-element(while.338), index=4
get-tuple-element.344 = u32[10,1]{1,0} get-tuple-element(while.338), index=5
get-tuple-element.345 = u32[10,1]{1,0} get-tuple-element(while.338), index=6
get-tuple-element.346 = u32[4]{0} get-tuple-element(while.338), index=7
get-tuple-element.347 = u32[4]{0} get-tuple-element(while.338), index=8
Arg_1.297 = s32[] parameter(1)
constant.307 = s32[] constant(-2147483648)
constant.306 = s32[] constant(2147483647)
call.310 = s32[] call(Arg_1.297, constant.307, constant.306), to_apply=clip_0.9
broadcast.370 = s32[10]{0} broadcast(call.310), dimensions={}
get-tuple-element.325 = u32[10,1]{1,0} get-tuple-element(while.322), index=2
reshape.332 = u32[10]{0} reshape(get-tuple-element.325)
Arg_2.298 = s32[] parameter(2)
call.308 = s32[] call(constant.306, constant.307, constant.306), to_apply=clip.3
compare.309 = pred[] compare(Arg_2.298, call.308), direction=GT
call.311 = s32[] call(Arg_2.298, constant.307, constant.306), to_apply=clip_0.15
compare.353 = pred[] compare(call.311, call.310), direction=GT
and.354 = pred[] and(compare.309, compare.353)
compare.351 = pred[] compare(call.311, call.310), direction=LE
constant.302 = u32[] constant(1)
subtract.349 = s32[] subtract(call.311, call.310)
convert.350 = u32[] convert(subtract.349)
select.352 = u32[] select(compare.351, constant.302, convert.350)
add.355 = u32[] add(select.352, constant.302)
select.356 = u32[] select(and.354, add.355, select.352)
broadcast.360 = u32[10]{0} broadcast(select.356), dimensions={}
remainder.361 = u32[10]{0} remainder(reshape.332, broadcast.360)
constant.301 = u32[] constant(65536)
remainder.357 = u32[] remainder(constant.301, select.356)
multiply.358 = u32[] multiply(remainder.357, remainder.357)
remainder.359 = u32[] remainder(multiply.358, select.356)
broadcast.362 = u32[10]{0} broadcast(remainder.359), dimensions={}
multiply.363 = u32[10]{0} multiply(remainder.361, broadcast.362)
get-tuple-element.341 = u32[10,1]{1,0} get-tuple-element(while.338), index=2
reshape.348 = u32[10]{0} reshape(get-tuple-element.341)
broadcast.364 = u32[10]{0} broadcast(select.356), dimensions={}
remainder.365 = u32[10]{0} remainder(reshape.348, broadcast.364)
add.366 = u32[10]{0} add(multiply.363, remainder.365)
broadcast.367 = u32[10]{0} broadcast(select.356), dimensions={}
remainder.368 = u32[10]{0} remainder(add.366, broadcast.367)
convert.369 = s32[10]{0} convert(remainder.368)
ROOT add.371 = s32[10]{0} add(broadcast.370, convert.369)
}
region_7.372 {
arg_tuple.373 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.374 = s32[] get-tuple-element(arg_tuple.373), index=0
constant.384 = s32[] constant(1)
add.432 = s32[] add(get-tuple-element.374, constant.384)
get-tuple-element.375 = s32[] get-tuple-element(arg_tuple.373), index=1
add.385 = s32[] add(get-tuple-element.375, constant.384)
get-tuple-element.376 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=2
get-tuple-element.377 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=3
add.388 = u32[10,1]{1,0} add(get-tuple-element.376, get-tuple-element.377)
get-tuple-element.381 = u32[4]{0} get-tuple-element(arg_tuple.373), index=7
slice.386 = u32[1]{0} slice(get-tuple-element.381), slice={[0:1]}
reshape.387 = u32[] reshape(slice.386)
broadcast.389 = u32[10,1]{1,0} broadcast(reshape.387), dimensions={}
shift-left.390 = u32[10,1]{1,0} shift-left(get-tuple-element.377, broadcast.389)
constant.383 = u32[] constant(32)
subtract.391 = u32[] subtract(constant.383, reshape.387)
broadcast.392 = u32[10,1]{1,0} broadcast(subtract.391), dimensions={}
shift-right-logical.393 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.377, broadcast.392)
or.394 = u32[10,1]{1,0} or(shift-left.390, shift-right-logical.393)
xor.395 = u32[10,1]{1,0} xor(add.388, or.394)
add.398 = u32[10,1]{1,0} add(add.388, xor.395)
slice.396 = u32[1]{0} slice(get-tuple-element.381), slice={[1:2]}
reshape.397 = u32[] reshape(slice.396)
broadcast.399 = u32[10,1]{1,0} broadcast(reshape.397), dimensions={}
shift-left.400 = u32[10,1]{1,0} shift-left(xor.395, broadcast.399)
subtract.401 = u32[] subtract(constant.383, reshape.397)
broadcast.402 = u32[10,1]{1,0} broadcast(subtract.401), dimensions={}
shift-right-logical.403 = u32[10,1]{1,0} shift-right-logical(xor.395, broadcast.402)
or.404 = u32[10,1]{1,0} or(shift-left.400, shift-right-logical.403)
xor.405 = u32[10,1]{1,0} xor(add.398, or.404)
add.408 = u32[10,1]{1,0} add(add.398, xor.405)
slice.406 = u32[1]{0} slice(get-tuple-element.381), slice={[2:3]}
reshape.407 = u32[] reshape(slice.406)
broadcast.409 = u32[10,1]{1,0} broadcast(reshape.407), dimensions={}
shift-left.410 = u32[10,1]{1,0} shift-left(xor.405, broadcast.409)
subtract.411 = u32[] subtract(constant.383, reshape.407)
broadcast.412 = u32[10,1]{1,0} broadcast(subtract.411), dimensions={}
shift-right-logical.413 = u32[10,1]{1,0} shift-right-logical(xor.405, broadcast.412)
or.414 = u32[10,1]{1,0} or(shift-left.410, shift-right-logical.413)
xor.415 = u32[10,1]{1,0} xor(add.408, or.414)
add.418 = u32[10,1]{1,0} add(add.408, xor.415)
get-tuple-element.378 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=4
add.426 = u32[10,1]{1,0} add(add.418, get-tuple-element.378)
slice.416 = u32[1]{0} slice(get-tuple-element.381), slice={[3:4]}
reshape.417 = u32[] reshape(slice.416)
broadcast.419 = u32[10,1]{1,0} broadcast(reshape.417), dimensions={}
shift-left.420 = u32[10,1]{1,0} shift-left(xor.415, broadcast.419)
subtract.421 = u32[] subtract(constant.383, reshape.417)
broadcast.422 = u32[10,1]{1,0} broadcast(subtract.421), dimensions={}
shift-right-logical.423 = u32[10,1]{1,0} shift-right-logical(xor.415, broadcast.422)
or.424 = u32[10,1]{1,0} or(shift-left.420, shift-right-logical.423)
xor.425 = u32[10,1]{1,0} xor(add.418, or.424)
get-tuple-element.379 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=5
add.427 = u32[10,1]{1,0} add(xor.425, get-tuple-element.379)
add.428 = s32[] add(get-tuple-element.375, constant.384)
convert.429 = u32[] convert(add.428)
broadcast.430 = u32[10,1]{1,0} broadcast(convert.429), dimensions={}
add.431 = u32[10,1]{1,0} add(add.427, broadcast.430)
get-tuple-element.380 = u32[10,1]{1,0} get-tuple-element(arg_tuple.373), index=6
get-tuple-element.382 = u32[4]{0} get-tuple-element(arg_tuple.373), index=8
ROOT tuple.433 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(add.432, add.385, add.426, add.431, get-tuple-element.379, get-tuple-element.380, get-tuple-element.378, get-tuple-element.382, get-tuple-element.381)
}
region_8.434 {
arg_tuple.435 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) parameter(0)
get-tuple-element.437 = s32[] get-tuple-element(arg_tuple.435), index=1
get-tuple-element.438 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=2
get-tuple-element.439 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=3
get-tuple-element.440 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=4
get-tuple-element.441 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=5
get-tuple-element.442 = u32[10,1]{1,0} get-tuple-element(arg_tuple.435), index=6
get-tuple-element.443 = u32[4]{0} get-tuple-element(arg_tuple.435), index=7
get-tuple-element.444 = u32[4]{0} get-tuple-element(arg_tuple.435), index=8
get-tuple-element.436 = s32[] get-tuple-element(arg_tuple.435), index=0
constant.445 = s32[] constant(5)
ROOT compare.446 = pred[] compare(get-tuple-element.436, constant.445), direction=LT
}
_uniform.447 {
constant.459 = s32[] constant(0)
Arg_0.448 = u32[10,2]{1,0} parameter(0)
slice.464 = u32[10,1]{1,0} slice(Arg_0.448), slice={[0:10], [0:1]}
slice.465 = u32[10,1]{1,0} slice(Arg_0.448), slice={[0:10], [1:2]}
xor.466 = u32[10,1]{1,0} xor(slice.464, slice.465)
constant.457 = u32[] constant(466688986)
broadcast.458 = u32[10,1]{1,0} broadcast(constant.457), dimensions={}
xor.467 = u32[10,1]{1,0} xor(xor.466, broadcast.458)
constant.461 = u32[4]{0} constant({13, 15, 26, 6})
constant.460 = u32[4]{0} constant({17, 29, 16, 24})
tuple.468 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) tuple(constant.459, constant.459, slice.464, slice.465, slice.465, xor.467, slice.464, constant.461, constant.460)
while.469 = (s32[], s32[], u32[10,1]{1,0}, u32[10,1]{1,0}, u32[10,1]{1,0}, /*index=5*/u32[10,1]{1,0}, u32[10,1]{1,0}, u32[4]{0}, u32[4]{0}) while(tuple.468), condition=region_8.434, body=region_7.372
get-tuple-element.470 = s32[] get-tuple-element(while.469), index=0
get-tuple-element.471 = s32[] get-tuple-element(while.469), index=1
get-tuple-element.473 = u32[10,1]{1,0} get-tuple-element(while.469), index=3
get-tuple-element.474 = u32[10,1]{1,0} get-tuple-element(while.469), index=4
get-tuple-element.475 = u32[10,1]{1,0} get-tuple-element(while.469), index=5
get-tuple-element.476 = u32[10,1]{1,0} get-tuple-element(while.469), index=6
get-tuple-element.477 = u32[4]{0} get-tuple-element(while.469), index=7
get-tuple-element.478 = u32[4]{0} get-tuple-element(while.469), index=8
Arg_1.449 = f32[] parameter(1)
reshape.494 = f32[1,1]{1,0} reshape(Arg_1.449)
broadcast.495 = f32[1,1]{1,0} broadcast(reshape.494), dimensions={0,1}
reshape.496 = f32[1]{0} reshape(broadcast.495)
broadcast.497 = f32[10,1]{1,0} broadcast(reshape.496), dimensions={1}
get-tuple-element.472 = u32[10,1]{1,0} get-tuple-element(while.469), index=2
constant.455 = u32[] constant(9)
broadcast.456 = u32[10,1]{1,0} broadcast(constant.455), dimensions={}
shift-right-logical.479 = u32[10,1]{1,0} shift-right-logical(get-tuple-element.472, broadcast.456)
constant.453 = u32[] constant(1065353216)
broadcast.454 = u32[10,1]{1,0} broadcast(constant.453), dimensions={}
or.480 = u32[10,1]{1,0} or(shift-right-logical.479, broadcast.454)
bitcast-convert.481 = f32[10,1]{1,0} bitcast-convert(or.480)
constant.451 = f32[] constant(1)
broadcast.452 = f32[10,1]{1,0} broadcast(constant.451), dimensions={}
subtract.482 = f32[10,1]{1,0} subtract(bitcast-convert.481, broadcast.452)
Arg_2.450 = f32[] parameter(2)
reshape.463 = f32[1]{0} reshape(Arg_2.450)
reshape.462 = f32[1]{0} reshape(Arg_1.449)
subtract.483 = f32[1]{0} subtract(reshape.463, reshape.462)
reshape.484 = f32[1,1]{1,0} reshape(subtract.483)
broadcast.485 = f32[1,1]{1,0} broadcast(reshape.484), dimensions={0,1}
reshape.486 = f32[1]{0} reshape(broadcast.485)
broadcast.487 = f32[10,1]{1,0} broadcast(reshape.486), dimensions={1}
multiply.488 = f32[10,1]{1,0} multiply(subtract.482, broadcast.487)
reshape.489 = f32[1,1]{1,0} reshape(Arg_1.449)
broadcast.490 = f32[1,1]{1,0} broadcast(reshape.489), dimensions={0,1}
reshape.491 = f32[1]{0} reshape(broadcast.490)
broadcast.492 = f32[10,1]{1,0} broadcast(reshape.491), dimensions={1}
add.493 = f32[10,1]{1,0} add(multiply.488, broadcast.492)
ROOT maximum.498 = f32[10,1]{1,0} maximum(broadcast.497, add.493)
}
region_9.499 {
arg_tuple.500 = (f32[10]{0}, f32[10,1]{1,0}) parameter(0)
get-tuple-element.502 = f32[10,1]{1,0} get-tuple-element(arg_tuple.500), index=1
reshape.511 = f32[10]{0} reshape(get-tuple-element.502)
constant.503 = f32[] constant(2)
broadcast.504 = f32[10]{0} broadcast(constant.503), dimensions={}
compare.512 = pred[10]{0} compare(reshape.511, broadcast.504), direction=LT
get-tuple-element.501 = f32[10]{0} get-tuple-element(arg_tuple.500), index=0
constant.507 = f32[] constant(0.1)
broadcast.508 = f32[10]{0} broadcast(constant.507), dimensions={}
add.509 = f32[10]{0} add(get-tuple-element.501, broadcast.508)
select.513 = f32[10]{0} select(compare.512, add.509, get-tuple-element.501)
reshape.514 = pred[10,1]{1,0} reshape(compare.512)
constant.505 = f32[] constant(0.1)
broadcast.506 = f32[10,1]{1,0} broadcast(constant.505), dimensions={}
add.510 = f32[10,1]{1,0} add(get-tuple-element.502, broadcast.506)
select.515 = f32[10,1]{1,0} select(reshape.514, add.510, get-tuple-element.502)
ROOT tuple.516 = (f32[10]{0}, f32[10,1]{1,0}) tuple(select.513, select.515)
}
region_11.517 {
Arg_0.518 = pred[] parameter(0)
Arg_1.519 = pred[] parameter(1)
ROOT or.520 = pred[] or(Arg_0.518, Arg_1.519)
}
region_10.521 {
arg_tuple.522 = (f32[10]{0}, f32[10,1]{1,0}) parameter(0)
get-tuple-element.523 = f32[10]{0} get-tuple-element(arg_tuple.522), index=0
get-tuple-element.524 = f32[10,1]{1,0} get-tuple-element(arg_tuple.522), index=1
reshape.528 = f32[10]{0} reshape(get-tuple-element.524)
constant.526 = f32[] constant(2)
broadcast.527 = f32[10]{0} broadcast(constant.526), dimensions={}
compare.529 = pred[10]{0} compare(reshape.528, broadcast.527), direction=LT
constant.525 = pred[] constant(false)
ROOT reduce.530 = pred[] reduce(compare.529, constant.525), dimensions={0}, to_apply=region_11.517
}
region_6.531 {
arg_tuple.532 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) parameter(0)
get-tuple-element.535 = s32[10]{0} get-tuple-element(arg_tuple.532), index=2
constant.536 = s32[] constant(5)
broadcast.537 = s32[10]{0} broadcast(constant.536), dimensions={}
compare.556 = pred[10]{0} compare(get-tuple-element.535, broadcast.537), direction=LT
get-tuple-element.533 = f32[10]{0} get-tuple-element(arg_tuple.532), index=0
constant.542 = s32[] constant(32)
broadcast.543 = s32[10]{0} broadcast(constant.542), dimensions={}
shift-right-logical.544 = s32[10]{0} shift-right-logical(get-tuple-element.535, broadcast.543)
convert.545 = u32[10]{0} convert(shift-right-logical.544)
reshape.546 = u32[10,1]{1,0} reshape(convert.545)
convert.547 = u32[10]{0} convert(get-tuple-element.535)
reshape.548 = u32[10,1]{1,0} reshape(convert.547)
concatenate.549 = u32[10,2]{1,0} concatenate(reshape.546, reshape.548), dimensions={1}
constant.541 = f32[] constant(0)
constant.540 = f32[] constant(1)
call.550 = f32[10,1]{1,0} call(concatenate.549, constant.541, constant.540), to_apply=_uniform.447
tuple.551 = (f32[10]{0}, f32[10,1]{1,0}) tuple(get-tuple-element.533, call.550)
while.552 = (f32[10]{0}, f32[10,1]{1,0}) while(tuple.551), condition=region_10.521, body=region_9.499
get-tuple-element.553 = f32[10]{0} get-tuple-element(while.552), index=0
select.557 = f32[10]{0} select(compare.556, get-tuple-element.553, get-tuple-element.533)
reshape.558 = pred[10,1]{1,0} reshape(compare.556)
get-tuple-element.554 = f32[10,1]{1,0} get-tuple-element(while.552), index=1
get-tuple-element.534 = f32[10,1]{1,0} get-tuple-element(arg_tuple.532), index=1
select.559 = f32[10,1]{1,0} select(reshape.558, get-tuple-element.554, get-tuple-element.534)
constant.538 = s32[] constant(1)
broadcast.539 = s32[10]{0} broadcast(constant.538), dimensions={}
add.555 = s32[10]{0} add(get-tuple-element.535, broadcast.539)
select.560 = s32[10]{0} select(compare.556, add.555, get-tuple-element.535)
ROOT tuple.561 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) tuple(select.557, select.559, select.560)
}
region_13.562 {
Arg_0.563 = pred[] parameter(0)
Arg_1.564 = pred[] parameter(1)
ROOT or.565 = pred[] or(Arg_0.563, Arg_1.564)
}
region_12.566 {
arg_tuple.567 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) parameter(0)
get-tuple-element.568 = f32[10]{0} get-tuple-element(arg_tuple.567), index=0
get-tuple-element.569 = f32[10,1]{1,0} get-tuple-element(arg_tuple.567), index=1
get-tuple-element.570 = s32[10]{0} get-tuple-element(arg_tuple.567), index=2
constant.572 = s32[] constant(5)
broadcast.573 = s32[10]{0} broadcast(constant.572), dimensions={}
compare.574 = pred[10]{0} compare(get-tuple-element.570, broadcast.573), direction=LT
constant.571 = pred[] constant(false)
ROOT reduce.575 = pred[] reduce(compare.574, constant.571), dimensions={0}, to_apply=region_13.562
}
solve.576 {
constant.579 = f32[] constant(0)
broadcast.580 = f32[10]{0} broadcast(constant.579), dimensions={}
Arg_0.577 = f32[10,1]{1,0} parameter(0)
Arg_1.578 = u32[10,2]{1,0} parameter(1)
constant.581 = s32[] constant(-2)
constant.582 = s32[] constant(2)
call.583 = s32[10]{0} call(Arg_1.578, constant.581, constant.582), to_apply=_randint.295
tuple.584 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) tuple(broadcast.580, Arg_0.577, call.583)
while.585 = (f32[10]{0}, f32[10,1]{1,0}, s32[10]{0}) while(tuple.584), condition=region_12.566, body=region_6.531
get-tuple-element.586 = f32[10]{0} get-tuple-element(while.585), index=0
ROOT get-tuple-element.587 = f32[10,1]{1,0} get-tuple-element(while.585), index=1
get-tuple-element.588 = s32[10]{0} get-tuple-element(while.585), index=2
}
ENTRY main.591 {
Arg_0.1 = f32[10,1]{1,0} parameter(0)
Arg_1.2 = u32[10,2]{1,0} parameter(1)
call.589 = f32[10,1]{1,0} call(Arg_0.1, Arg_1.2), to_apply=solve.576
ROOT tuple.590 = (f32[10,1]{1,0}) tuple(call.589)
}
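One way to read the batched loops in the dump above: the inner loop condition (region_10.521) OR-reduces `y < 2` over all 10 batch elements, and the inner loop body (region_9.499) uses `select` to leave already-finished elements unchanged. So the batched loop keeps iterating until the slowest element is done. Below is a minimal pure-Python sketch of that behaviour, not JAX code. The helper names `batched_while` and `per_lane_while` are made up for illustration, and integers (in tenths) stand in for the floats to avoid rounding noise. Whether this has anything to do with the sharding slowdown is not something the dump alone shows.

```python
def batched_while(cond, body, states):
    """Emulate the batched loop seen in the HLO (illustrative helper)."""
    states = list(states)
    steps = 0
    # Like reduce(..., to_apply=or): keep going while ANY lane is unfinished.
    while any(cond(s) for s in states):
        # Like select(pred, new, old): only unfinished lanes are updated.
        states = [body(s) if cond(s) else s for s in states]
        steps += 1
    return states, steps


def per_lane_while(cond, body, states):
    """Run an independent while loop per lane (illustrative helper)."""
    out, steps = [], []
    for s in states:
        n = 0
        while cond(s):
            s = body(s)
            n += 1
        out.append(s)
        steps.append(n)
    return out, steps


# Stand-ins for inner_loop_cond / inner_loop_body, scaled by 10:
# y < 2 becomes y < 20, and y + 0.1 becomes y + 1.
cond = lambda y: y < 20
body = lambda y: y + 1

initial = [0, 5, 19, 12]  # assumed starting values, one per lane
batched_states, batched_steps = batched_while(cond, body, initial)
lane_states, lane_steps = per_lane_while(cond, body, initial)

print("batched steps:", batched_steps)
print("per-lane steps:", lane_steps)
```

In this sketch the batched loop takes as many steps as the slowest lane, while each per-lane loop stops as soon as its own condition fails. Both versions end in the same final states.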
For performance-related issues like this, the cause is usually in XLA. JAX is mostly at the mercy of whatever code XLA generates.
Unfortunately the parallelism part of this isn't something I'm familiar with at all. I think @sharadmv might know more? This one is out of my wheelhouse I'm afraid.
Description
As the title indicates, with a double while loop, where the inner while loop may change in length over outer while loop steps, pmap is substantially faster than sharding. This may sound contrived, but it is exactly what happens in other packages, such as diffrax, where I first identified this issue: https://github.com/patrick-kidger/diffrax/issues/407. I believe there are two possibilities: 1) I am using sharding wrong and that is why it is slow (very possible, I am new to sharding), or 2) something else is going on in sharding. I have included an MVC below. I ran on both CPU and GPU, and the difference is even more noticeable on GPU.
CPU: 5.11 ms ± 53.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
GPU: 1.18 s ± 48.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
CPU: 251 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
GPU: 3.93 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
System info (python version, jaxlib version, accelerator, etc.)
CPU:
GPU: