jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unexpected speedup from wrapping function call in trivial jax.lax.cond statement #21065

Open TonyZhou729 opened 6 months ago

TonyZhou729 commented 6 months ago

Description

Hi,

We noticed a strange speed-up when a function is called through a trivial lax.cond statement instead of being called directly.

In the reproduction below, we JIT-compile a main() function that contains a lax.scan() loop. In each loop step, we can optionally wrap the function we call in a lax.cond() whose condition is that the loop index i (which runs from 0 to Ny-1, for Ny steps) is greater than -1, which is always true. This seemingly unnecessary wrapper somehow makes the code faster.
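A minimal sketch of the pattern, using hypothetical stand-in functions (`true_fn`, `false_fn`, `step_direct`, `step_cond`) rather than the code from this issue. Because `i` comes from the array that `lax.scan` iterates over, it is a traced value, so `i > -1` cannot be decided while the function is being traced and the `lax.cond` stays in the program even though it always takes the same branch. Whether XLA later simplifies it away is up to the compiler; in this toy the two variants compute the same numbers.

```python
import jax
import jax.numpy as jnp
from jax import lax


def true_fn(x):
    # Hypothetical stand-in for interp_A_from_B: any nontrivial computation.
    return jnp.sin(x) + 1.0


def false_fn(x):
    # Hypothetical stand-in for false_func; never taken because i > -1 always holds.
    return jnp.zeros_like(x)


def step_direct(carry, i):
    # Case 1: call the function directly.
    out = true_fn(carry)
    return out, out


def step_cond(carry, i):
    # Case 2: i is a tracer inside lax.scan, so `i > -1` is not known at trace
    # time and lax.cond is kept in the traced program.
    out = lax.cond(i > -1, true_fn, false_fn, carry)
    return out, out


x0 = jnp.linspace(0.0, 1.0, 8)
idx = jnp.arange(5)

_, ys_direct = jax.jit(lambda x: lax.scan(step_direct, x, idx))(x0)
_, ys_cond = jax.jit(lambda x: lax.scan(step_cond, x, idx))(x0)

# Identical results, but only the second program contains a `cond` primitive.
jaxpr_cond = jax.make_jaxpr(lambda x: lax.scan(step_cond, x, idx))(x0)
print(jnp.allclose(ys_direct, ys_cond))
print("cond" in str(jaxpr_cond))
```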

import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates
import time

Nx = 300
Ny = 100000
x_axis = jnp.linspace(5., 12.75, Nx)

@jit
def main():
    y_axis = jnp.linspace(0, 1, Ny)

    # The initial value of B is an (Nx, Ny) array of zeros.
    B = jnp.zeros((Nx, Ny), dtype="float32")

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]

        """ Obtain an array A using interp_A_from_B(), picking one of three ways """
        # Case 1: We simply run interp_A_from_B() every step
        A = interp_A_from_B((y, y_axis, B))

        # Case 2: We use a seemingly trivial lax.cond wrapper, but will still always run
        # interp_A_from_B since index i is always greater than -1.
        # For some reason we observe a speed up over case 1.
        #A = lax.cond(i>-1, interp_A_from_B, false_func, (y, y_axis, B))

        # Update B array with values of A from this loop.
        B = set_B_to_A(i, B, A)

        return B, None

    # Use lax.scan to run loop and update B Ny times.
    # Index i will run through jnp.arange(Ny) = (0, 1, 2, ..., Ny-1)
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))

    return B

def interp_A_from_B(params):
    # B is a (Nx, Ny) array.
    # A is a (Nx,) array.

    y, y_axis, B    = params
    # Precise value of y to interpolate at.
    y_prime         = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx-1])
    # Convert to index position within y_axis, to use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp          = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Only use the interpolated result where y_prime is not below the smallest y in y_axis.
    condition       = y_prime < y_axis[0]

    # Put A array together, with some fill in values for where we don't want the interpolated value.
    A               = condition * jnp.exp(-x_axis[:Nx-1]) \
                    + (1-condition) * interp
    A               = jnp.append(A, jnp.exp(-x_axis[-1]))

    return A

def set_B_to_A(i, B, A):
    # Update a column of B with the current value of A.
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false function, sets all entries of A to some fill values if called.
    A = jnp.exp(-x_axis)
    return A

""" Running main() a couple times to see the speed """
for i in range(5):
    s = time.time()
    B = main()
    print(time.time() - s)
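The interpolation inside interp_A_from_B happens in two steps: `jnp.interp` turns each `y_prime` value into a fractional column index of `y_axis`, and `map_coordinates(..., order=1)` then linearly interpolates `B` at those (row, column) positions. The toy sketch below uses small made-up arrays, not the sizes from the issue. Note that `jnp.interp` clamps values below `y_axis[0]` to index 0, which is why the original code overwrites those entries with fill values via `condition`.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Toy grid (made-up sizes): 5 evenly spaced y values and a 3 x 5 array B.
y_axis = jnp.linspace(0.0, 1.0, 5)    # [0.0, 0.25, 0.5, 0.75, 1.0]
B = jnp.arange(15.0).reshape(3, 5)    # row r holds 5r, 5r + 1, ..., 5r + 4

# Step 1: convert y values to fractional column indices of y_axis.
# jnp.interp clamps inputs below y_axis[0] to the first index (0.0).
y_prime = jnp.array([0.6, -0.3])
cols = jnp.interp(y_prime, y_axis, jnp.arange(5))   # approx. [2.4, 0.0]

# Step 2: linearly interpolate B at (row, fractional column) pairs.
rows = jnp.array([1.0, 1.0])
vals = map_coordinates(B, [rows, cols], order=1)    # approx. [7.4, 5.0]
print(cols, vals)
```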

When using case 1 in loop_in_main() and calling main() 5 times, we observe the following runtimes (in seconds):

2.593395233154297
2.245166540145874
2.2276456356048584
2.242725372314453
2.2277612686157227

Switching to case 2, we see:

2.095738649368286
1.769906997680664
1.7625515460968018
1.7623748779296875
1.7783701419830322

In both cases the first run takes longer because of JIT compilation. We checked that this speed-up scales with Ny, the number of steps in lax.scan. In our real code, which does more computation in each step, the speed-up is even more significant.
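For a rough sense of scale, the posted timings can be reduced to steady-state averages by dropping the first (compilation) run of each case. This is a back-of-the-envelope calculation on the reporter's numbers only, which were collected without `jax.block_until_ready`: it gives about 2.24 s versus 1.77 s per call, a ratio of roughly 1.26x, or about 4.7 µs saved per `lax.scan` step at Ny = 100000.

```python
# Timings in seconds as posted in the issue; index 0 of each list is the
# first call, which includes JIT compilation, so it is excluded below.
case1 = [2.593395233154297, 2.245166540145874, 2.2276456356048584,
         2.242725372314453, 2.2277612686157227]
case2 = [2.095738649368286, 1.769906997680664, 1.7625515460968018,
         1.7623748779296875, 1.7783701419830322]
Ny = 100000  # number of lax.scan steps in the reproduction

steady1 = sum(case1[1:]) / len(case1[1:])
steady2 = sum(case2[1:]) / len(case2[1:])
speedup = steady1 / steady2
saved_us_per_step = (steady1 - steady2) / Ny * 1e6

print(f"case 1: {steady1:.2f} s, case 2: {steady2:.2f} s")
print(f"ratio: {speedup:.2f}x, about {saved_us_per_step:.1f} us saved per scan step")
```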

Thank you in advance for your help and comments! Tony

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.16
jaxlib: 0.4.16
numpy:  1.24.3
python: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
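The system-info block above matches the format printed by `jax.print_environment_info()`. A minimal sketch for capturing the same summary when re-running the reproduction on another release, assuming the `return_string` argument available in recent JAX versions:

```python
import jax

# Collect the version/device summary as a string (return_string=True)
# instead of printing it directly, e.g. to paste into a bug report.
info = jax.print_environment_info(return_string=True)
print(info)
```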
cgiovanetti commented 2 weeks ago

Just a note that I ran this example with jax==0.4.34 and the issue persists.

jakevdp commented 2 weeks ago

Sorry I missed this originally – first of all, when you are running these kinds of microbenchmarks, be sure to follow the recommendations at FAQ: Benchmarking JAX code. In particular, you should wrap the computation of interest in jax.block_until_ready to ensure you're measuring computation time rather than just dispatch time:

import jax  # the script above imports only jit and lax from jax

for i in range(5):
    s = time.time()
    B = jax.block_until_ready(main())
    print(time.time() - s)
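A slightly more systematic version of this timing loop, as a sketch: `benchmark` and `workload` are hypothetical names, compilation is excluded by a separate warm-up call, and the minimum over several runs is reported (one common convention; a median is another reasonable choice). With the reproduction above it would be called as `benchmark(main)`.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=5):
    """Return the best wall-clock time (seconds) of fn(*args) over `repeats` runs.

    A separate warm-up call triggers JIT compilation so it is not timed, and
    every call is wrapped in jax.block_until_ready so that asynchronous
    dispatch does not hide the actual computation time.
    """
    jax.block_until_ready(fn(*args))  # warm-up: compile once, untimed
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    return min(times)


# Hypothetical small workload standing in for main().
@jax.jit
def workload(x):
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))
best = benchmark(workload, x)
print(f"best of 5: {best:.6f} s")
```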

Still, even with this I'm seeing the same general behavior you are. It looks like adding the cond leads the XLA compiler to choose different fusions, which in turn give different performance characteristics. You can see this by using the Ahead-of-time compilation tools to output the compiled HLO:

print(main.lower().compile().as_text())

For case 1, the output is this:

Click to expand ``` HloModule jit_main1, is_scheduled=true, entry_computation_layout={()->f32[300,100000]{1,0}}, allow_spmd_sharding_propagation_to_output={true} %fused_computation (param_0.2: s32[299], param_1.2: s32[299], param_2.3: f32[299,1], param_3.8: s32[299]) -> s32[299] { %param_1.2 = s32[299]{0} parameter(1) %param_2.3 = f32[299,1]{1,0} parameter(2) %bitcast.14 = f32[299]{0} bitcast(f32[299,1]{1,0} %param_2.3), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %compare.351 = pred[299]{0} compare(f32[299]{0} %bitcast.14, f32[299]{0} %bitcast.14), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.492 = f32[] constant(nan) %broadcast.508 = f32[299]{0} broadcast(f32[] %constant.492), dimensions={} %constant.491 = f32[] constant(0) %broadcast.507 = f32[299]{0} broadcast(f32[] %constant.491), dimensions={} %compare.350 = pred[299]{0} compare(f32[299]{0} %bitcast.14, f32[299]{0} %broadcast.507), direction=EQ, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.339 = f32[299]{0} select(pred[299]{0} %compare.350, f32[299]{0} %broadcast.507, f32[299]{0} %bitcast.14), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.338 = f32[299]{0} select(pred[299]{0} %compare.351, f32[299]{0} %broadcast.508, f32[299]{0} %select.339), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.17 = s32[299]{0} bitcast-convert(f32[299]{0} %select.338) %constant.493 = s32[] constant(0) %broadcast.510 = s32[299]{0} broadcast(s32[] %constant.493), dimensions={} %compare.349 = pred[299]{0} compare(s32[299]{0} 
%bitcast-convert.17, s32[299]{0} %broadcast.510), direction=LT %constant.490 = s32[] constant(2147483647) %broadcast.506 = s32[299]{0} broadcast(s32[] %constant.490), dimensions={} %xor.17 = s32[299]{0} xor(s32[299]{0} %broadcast.506, s32[299]{0} %bitcast-convert.17) %select.337 = s32[299]{0} select(pred[299]{0} %compare.349, s32[299]{0} %xor.17, s32[299]{0} %bitcast-convert.17) %compare.348 = pred[299]{0} compare(s32[299]{0} %param_1.2, s32[299]{0} %select.337), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt_to" source_file="" source_line=58} %param_3.8 = s32[299]{0} parameter(3) %param_0.2 = s32[299]{0} parameter(0) %add.291 = s32[299]{0} add(s32[299]{0} %param_3.8, s32[299]{0} %param_0.2), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.13 = s32[299]{0} sign(s32[299]{0} %add.291), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.494 = s32[] constant(1) %broadcast.511 = s32[299]{0} broadcast(s32[] %constant.494), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.355 = pred[299]{0} compare(s32[299]{0} %sign.13, s32[299]{0} %broadcast.511), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.353 = pred[299]{0} compare(s32[299]{0} %add.291, s32[299]{0} %broadcast.510), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.56 = s32[299]{0} negate(s32[299]{0} %add.291), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.345 = s32[299]{0} select(pred[299]{0} %compare.353, s32[299]{0} %negate.56, s32[299]{0} %add.291), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.79 = s32[299]{0} and(s32[299]{0} %select.345, s32[299]{0} %broadcast.511), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.55 = s32[299]{0} negate(s32[299]{0} %and.79), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.343 = s32[299]{0} select(pred[299]{0} %compare.353, s32[299]{0} %negate.55, s32[299]{0} %and.79), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.352 = pred[299]{0} compare(s32[299]{0} %select.343, s32[299]{0} %broadcast.510), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %and.78 = pred[299]{0} and(pred[299]{0} %compare.355, pred[299]{0} %compare.352), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.6 = s32[299]{0} shift-right-logical(s32[299]{0} %select.345, s32[299]{0} %broadcast.511), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.54 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.6), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.342 = s32[299]{0} select(pred[299]{0} %compare.353, s32[299]{0} %negate.54, s32[299]{0} %shift-right-logical.6), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.231.clone.2.clone.6 = s32[] constant(-1) %broadcast.509 = s32[299]{0} broadcast(s32[] %constant.231.clone.2.clone.6), dimensions={} %add.289 = s32[299]{0} add(s32[299]{0} %select.342, s32[299]{0} %broadcast.509), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.340 = s32[299]{0} select(pred[299]{0} %and.78, s32[299]{0} %add.289, s32[299]{0} %select.342), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} ROOT %select.336 = s32[299]{0} select(pred[299]{0} %compare.348, s32[299]{0} %select.340, s32[299]{0} %param_0.2), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} } %fused_computation.1 (param_0.4: s32[299], param_1.5: s32[299], param_2.7: f32[299,1], param_3.17: s32[299]) -> s32[299] { %param_1.5 = s32[299]{0} parameter(1) %param_2.7 = f32[299,1]{1,0} parameter(2) %bitcast.15 = f32[299]{0} bitcast(f32[299,1]{1,0} %param_2.7), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %compare.360 = pred[299]{0} compare(f32[299]{0} %bitcast.15, f32[299]{0} %bitcast.15), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.497 = 
f32[] constant(nan) %broadcast.514 = f32[299]{0} broadcast(f32[] %constant.497), dimensions={} %constant.496 = f32[] constant(0) %broadcast.513 = f32[299]{0} broadcast(f32[] %constant.496), dimensions={} %compare.359 = pred[299]{0} compare(f32[299]{0} %bitcast.15, f32[299]{0} %broadcast.513), direction=EQ, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.349 = f32[299]{0} select(pred[299]{0} %compare.359, f32[299]{0} %broadcast.513, f32[299]{0} %bitcast.15), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.348 = f32[299]{0} select(pred[299]{0} %compare.360, f32[299]{0} %broadcast.514, f32[299]{0} %select.349), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.18 = s32[299]{0} bitcast-convert(f32[299]{0} %select.348) %constant.498 = s32[] constant(0) %broadcast.516 = s32[299]{0} broadcast(s32[] %constant.498), dimensions={} %compare.358 = pred[299]{0} compare(s32[299]{0} %bitcast-convert.18, s32[299]{0} %broadcast.516), direction=LT %constant.495 = s32[] constant(2147483647) %broadcast.512 = s32[299]{0} broadcast(s32[] %constant.495), dimensions={} %xor.18 = s32[299]{0} xor(s32[299]{0} %broadcast.512, s32[299]{0} %bitcast-convert.18) %select.347 = s32[299]{0} select(pred[299]{0} %compare.358, s32[299]{0} %xor.18, s32[299]{0} %bitcast-convert.18) %compare.356 = pred[299]{0} compare(s32[299]{0} %param_1.5, s32[299]{0} %select.347), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt_to" source_file="" source_line=58} %param_0.4 = s32[299]{0} parameter(0) %param_3.17 = s32[299]{0} parameter(3) %add.293 = s32[299]{0} add(s32[299]{0} %param_0.4, s32[299]{0} %param_3.17), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.14 = s32[299]{0} sign(s32[299]{0} %add.293), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.499 = s32[] constant(1) %broadcast.517 = s32[299]{0} broadcast(s32[] %constant.499), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.363 = pred[299]{0} compare(s32[299]{0} %sign.14, s32[299]{0} %broadcast.517), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.362 = pred[299]{0} compare(s32[299]{0} %add.293, s32[299]{0} %broadcast.516), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.59 = s32[299]{0} negate(s32[299]{0} %add.293), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.353 = s32[299]{0} select(pred[299]{0} %compare.362, s32[299]{0} %negate.59, s32[299]{0} %add.293), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.81 = s32[299]{0} and(s32[299]{0} %select.353, s32[299]{0} %broadcast.517), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.58 = s32[299]{0} negate(s32[299]{0} %and.81), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" 
source_file="" source_line=58} %select.352 = s32[299]{0} select(pred[299]{0} %compare.362, s32[299]{0} %negate.58, s32[299]{0} %and.81), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.361 = pred[299]{0} compare(s32[299]{0} %select.352, s32[299]{0} %broadcast.516), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %and.80 = pred[299]{0} and(pred[299]{0} %compare.363, pred[299]{0} %compare.361), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.7 = s32[299]{0} shift-right-logical(s32[299]{0} %select.353, s32[299]{0} %broadcast.517), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.57 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.7), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.351 = s32[299]{0} select(pred[299]{0} %compare.362, s32[299]{0} %negate.57, s32[299]{0} %shift-right-logical.7), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.231.clone.2.clone.7 = s32[] constant(-1) %broadcast.515 = s32[299]{0} broadcast(s32[] %constant.231.clone.2.clone.7), dimensions={} %add.292 = s32[299]{0} add(s32[299]{0} %select.351, s32[299]{0} %broadcast.515), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.350 = s32[299]{0} select(pred[299]{0} %and.80, s32[299]{0} %add.292, 
s32[299]{0} %select.351), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} ROOT %select.346 = s32[299]{0} select(pred[299]{0} %compare.356, s32[299]{0} %param_0.4, s32[299]{0} %select.350), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} } %fused_computation.2 (param_0.5: f32[100000], param_1.14: s32[299], param_2.19: s32[299]) -> f32[299,1] { %param_0.5 = f32[100000]{0} parameter(0) %param_1.14 = s32[299]{0} parameter(1) %param_2.19 = s32[299]{0} parameter(2) %add.296 = s32[299]{0} add(s32[299]{0} %param_1.14, s32[299]{0} %param_2.19), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.15 = s32[299]{0} sign(s32[299]{0} %add.296), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.502 = s32[] constant(1) %broadcast.521 = s32[299]{0} broadcast(s32[] %constant.502), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.367 = pred[299]{0} compare(s32[299]{0} %sign.15, s32[299]{0} %broadcast.521), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %constant.501 = s32[] constant(0) %broadcast.520 = s32[299]{0} broadcast(s32[] %constant.501), dimensions={} %compare.366 = pred[299]{0} compare(s32[299]{0} %add.296, s32[299]{0} %broadcast.520), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.62 = 
s32[299]{0} negate(s32[299]{0} %add.296), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.360 = s32[299]{0} select(pred[299]{0} %compare.366, s32[299]{0} %negate.62, s32[299]{0} %add.296), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.83 = s32[299]{0} and(s32[299]{0} %select.360, s32[299]{0} %broadcast.521), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.61 = s32[299]{0} negate(s32[299]{0} %and.83), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.358 = s32[299]{0} select(pred[299]{0} %compare.366, s32[299]{0} %negate.61, s32[299]{0} %and.83), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.365 = pred[299]{0} compare(s32[299]{0} %select.358, s32[299]{0} %broadcast.520), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %and.82 = pred[299]{0} and(pred[299]{0} %compare.367, pred[299]{0} %compare.365), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.8 = s32[299]{0} shift-right-logical(s32[299]{0} %select.360, s32[299]{0} %broadcast.521), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.60 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.8), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.357 = s32[299]{0} select(pred[299]{0} %compare.366, s32[299]{0} %negate.60, s32[299]{0} %shift-right-logical.8), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.231.clone.2.clone.8 = s32[] constant(-1) %broadcast.519 = s32[299]{0} broadcast(s32[] %constant.231.clone.2.clone.8), dimensions={} %add.295 = s32[299]{0} add(s32[299]{0} %select.357, s32[299]{0} %broadcast.519), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.355 = s32[299]{0} select(pred[299]{0} %and.82, s32[299]{0} %add.295, s32[299]{0} %select.357), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} %compare.364 = pred[299]{0} compare(s32[299]{0} %select.355, s32[299]{0} %broadcast.520), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt" source_file="" source_line=58} %constant.500 = s32[] constant(100000) %broadcast.518 = s32[299]{0} broadcast(s32[] %constant.500), dimensions={} %add.294 = s32[299]{0} add(s32[299]{0} %select.355, s32[299]{0} %broadcast.518), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %select.354 = s32[299]{0} select(pred[299]{0} %compare.364, s32[299]{0} %add.294, s32[299]{0} %select.355), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast.16 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.354), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} ROOT %gather.95 = f32[299,1]{1,0} gather(f32[100000]{0} %param_0.5, s32[299,1]{1,0} %bitcast.16), offset_dims={1}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} } %wide.region_1.101.clone.1.clone.1 (wide.arg_tuple.9: (s32[], s32[299], s32[299], f32[100000], s32[299])) -> (s32[], s32[299], s32[299], f32[100000], s32[299]) { %wide.arg_tuple.9 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) parameter(0) %get-tuple-element.619 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.9), index=1 %copy.4 = s32[299]{0} copy(s32[299]{0} %get-tuple-element.619) %get-tuple-element.620 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.9), index=2 %copy.5 = s32[299]{0} copy(s32[299]{0} %get-tuple-element.620) %get-tuple-element.626 = f32[100000]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.9), index=3 %fusion.2 = f32[299,1]{1,0} fusion(f32[100000]{0} %get-tuple-element.626, s32[299]{0} %copy.4, s32[299]{0} %copy.5), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %get-tuple-element.627 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.9), index=4 %fusion = s32[299]{0} fusion(s32[299]{0} %copy.5, s32[299]{0} %get-tuple-element.627, f32[299,1]{1,0} %fusion.2, s32[299]{0} %copy.4), kind=kLoop, calls=%fused_computation, 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} %fusion.1 = s32[299]{0} fusion(s32[299]{0} %copy.4, s32[299]{0} %get-tuple-element.627, f32[299,1]{1,0} %fusion.2, s32[299]{0} %copy.5), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} %constant.478 = s32[] constant(1) %get-tuple-element.618 = s32[] get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.9), index=0 %add.285 = s32[] add(s32[] %get-tuple-element.618, s32[] %constant.478), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} ROOT %tuple.59 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) tuple(s32[] %add.285, s32[299]{0} %fusion.1, s32[299]{0} %fusion, f32[100000]{0} %get-tuple-element.626, s32[299]{0} %get-tuple-element.627) } %wide.region_2.114.clone.1.clone.1 (wide.arg_tuple.8: (s32[], s32[299], s32[299], f32[100000], s32[299])) -> pred[] { %constant.477 = s32[] constant(17) %wide.arg_tuple.8 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) parameter(0) %get-tuple-element.599 = s32[] get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.8), index=0 ROOT %compare.336 = pred[] compare(s32[] %get-tuple-element.599, s32[] %constant.477), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/cond/lt" source_file="" source_line=58} } %fused_computation.3 (param_0.6: f32[300,100000], param_1.17: s32[], param_2.22: f32[299], param_3.34: f32[299], param_4.46: f32[299], param_5.41: pred[299], param_6.42: f32[299], param_7.25: f32[299], param_8.24: f32[299], param_9.13: f32[299], param_10.15: f32[299], param_11.18: 
f32[299], param_12.18: f32[299], param_13.20: f32[299], param_14.18: f32[100000], param_15.20: f32[100000], param_16.24: f32[299], param_17.25: pred[]) -> f32[300,100000] { %param_0.6 = f32[300,100000]{1,0} parameter(0) %param_17.25 = pred[] parameter(17) %broadcast.533 = pred[300,1]{1,0} broadcast(pred[] %param_17.25), dimensions={} %param_15.20 = f32[100000]{0} parameter(15) %param_1.17 = s32[] parameter(1) %dynamic-slice.24 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_15.20, s32[] %param_1.17), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.22 = f32[] bitcast(f32[1]{0} %dynamic-slice.24), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.532 = f32[299]{0} broadcast(f32[] %bitcast.22), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_16.24 = f32[299]{0} parameter(16) %add.304 = f32[299]{0} add(f32[299]{0} %broadcast.532, f32[299]{0} %param_16.24), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %slice.62 = f32[1]{0} slice(f32[100000]{0} %param_15.20), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %bitcast.21 = f32[] bitcast(f32[1]{0} %slice.62), metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %broadcast.531 = f32[299]{0} broadcast(f32[] %bitcast.21), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %compare.374 = pred[299]{0} compare(f32[299]{0} %add.304, f32[299]{0} %broadcast.531), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %constant.229..sunk.4 = f32[299]{0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/exp" source_file="" source_line=65} %constant.506 = f32[] constant(0) %broadcast.530 = 
f32[299]{0} broadcast(f32[] %constant.506), dimensions={} %select.369 = f32[299]{0} select(pred[299]{0} %compare.374, f32[299]{0} %constant.229..sunk.4, f32[299]{0} %broadcast.530), metadata={op_name="jit(main1)/jit(main)/while/body/mul" source_file="" source_line=65} %constant.509 = s32[] constant(1) %broadcast.529 = s32[299]{0} broadcast(s32[] %constant.509), dimensions={} %convert.58 = s32[299]{0} convert(pred[299]{0} %compare.374), metadata={op_name="jit(main1)/jit(main)/while/body/convert_element_type" source_file="" source_line=66} %subtract.75 = s32[299]{0} subtract(s32[299]{0} %broadcast.529, s32[299]{0} %convert.58), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=66} %convert.57 = f32[299]{0} convert(s32[299]{0} %subtract.75), metadata={op_name="jit(main1)/jit(main)/while/body/convert_element_type" source_file="" source_line=66} %param_8.24 = f32[299]{0} parameter(8) %constant.505 = f32[] constant(1) %broadcast.528 = f32[299]{0} broadcast(f32[] %constant.505), dimensions={} %slice.61 = f32[1]{0} slice(f32[100000]{0} %param_15.20), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.20 = f32[] bitcast(f32[1]{0} %slice.61), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.527 = f32[299]{0} broadcast(f32[] %bitcast.20), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %compare.373 = pred[299]{0} compare(f32[299]{0} %add.304, f32[299]{0} %broadcast.527), direction=GT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %param_14.18 = f32[100000]{0} parameter(14) %slice.60 = f32[1]{0} slice(f32[100000]{0} %param_14.18), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} 
%bitcast.19 = f32[] bitcast(f32[1]{0} %slice.60), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.526 = f32[299]{0} broadcast(f32[] %bitcast.19), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.59 = f32[1]{0} slice(f32[100000]{0} %param_14.18), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %bitcast.18 = f32[] bitcast(f32[1]{0} %slice.59), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %broadcast.525 = f32[299]{0} broadcast(f32[] %bitcast.18), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_12.18 = f32[299]{0} parameter(12) %param_13.20 = f32[299]{0} parameter(13) %subtract.74 = f32[299]{0} subtract(f32[299]{0} %param_12.18, f32[299]{0} %param_13.20), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.8 = f32[299]{0} abs(f32[299]{0} %subtract.74), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.504 = f32[] constant(1.42108547e-14) %broadcast.522 = f32[299]{0} broadcast(f32[] %constant.504), dimensions={} %compare.372 = pred[299]{0} compare(f32[299]{0} %abs.8, f32[299]{0} %broadcast.522), direction=LE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %param_9.13 = f32[299]{0} parameter(9) %param_10.15 = f32[299]{0} parameter(10) %param_11.18 = f32[299]{0} parameter(11) %subtract.73 = f32[299]{0} subtract(f32[299]{0} %param_11.18, f32[299]{0} %param_9.13), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %multiply.102 = f32[299]{0} multiply(f32[299]{0} %param_10.15, 
f32[299]{0} %subtract.73), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/mul" source_file="" source_line=58} %add.303 = f32[299]{0} add(f32[299]{0} %param_9.13, f32[299]{0} %multiply.102), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.368 = f32[299]{0} select(pred[299]{0} %compare.372, f32[299]{0} %param_9.13, f32[299]{0} %add.303), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.367 = f32[299]{0} select(pred[299]{0} %compare.374, f32[299]{0} %broadcast.525, f32[299]{0} %select.368), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.366 = f32[299]{0} select(pred[299]{0} %compare.373, f32[299]{0} %broadcast.526, f32[299]{0} %select.367), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.8 = f32[299]{0} floor(f32[299]{0} %select.366), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/floor" source_file="" source_line=60} %subtract.72 = f32[299]{0} subtract(f32[299]{0} %select.366, f32[299]{0} %floor.8), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/sub" source_file="" source_line=60} %subtract.71 = f32[299]{0} subtract(f32[299]{0} %broadcast.528, f32[299]{0} %subtract.72), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/sub" source_file="" source_line=60} %multiply.101 = f32[299]{0} multiply(f32[299]{0} %param_8.24, f32[299]{0} %subtract.71), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %param_5.41 = pred[299]{0} parameter(5) %convert.56 = s32[299]{0} convert(f32[299]{0} %floor.8), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %constant.508 = s32[] 
constant(0) %broadcast.524 = s32[299]{0} broadcast(s32[] %constant.508), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.371 = pred[299]{0} compare(s32[299]{0} %convert.56, s32[299]{0} %broadcast.524), direction=GE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/ge" source_file="" source_line=60} %constant.507 = s32[] constant(100000) %broadcast.523 = s32[299]{0} broadcast(s32[] %constant.507), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.370 = pred[299]{0} compare(s32[299]{0} %convert.56, s32[299]{0} %broadcast.523), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.89 = pred[299]{0} and(pred[299]{0} %compare.371, pred[299]{0} %compare.370), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %and.88 = pred[299]{0} and(pred[299]{0} %param_5.41, pred[299]{0} %and.89), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %param_7.25 = f32[299]{0} parameter(7) %select.365 = f32[299]{0} select(pred[299]{0} %and.88, f32[299]{0} %param_7.25, f32[299]{0} %broadcast.530), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.100 = f32[299]{0} multiply(f32[299]{0} %multiply.101, f32[299]{0} %select.365), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %param_6.42 = f32[299]{0} parameter(6) %multiply.99 = f32[299]{0} multiply(f32[299]{0} %param_6.42, f32[299]{0} %subtract.72), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.302 = s32[299]{0} 
add(s32[299]{0} %convert.56, s32[299]{0} %broadcast.529), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %compare.369 = pred[299]{0} compare(s32[299]{0} %add.302, s32[299]{0} %broadcast.524), direction=GE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/ge" source_file="" source_line=60} %compare.368 = pred[299]{0} compare(s32[299]{0} %add.302, s32[299]{0} %broadcast.523), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.87 = pred[299]{0} and(pred[299]{0} %compare.369, pred[299]{0} %compare.368), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %and.86 = pred[299]{0} and(pred[299]{0} %param_5.41, pred[299]{0} %and.87), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %param_4.46 = f32[299]{0} parameter(4) %select.364 = f32[299]{0} select(pred[299]{0} %and.86, f32[299]{0} %param_4.46, f32[299]{0} %broadcast.530), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.98 = f32[299]{0} multiply(f32[299]{0} %multiply.99, f32[299]{0} %select.364), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.300 = f32[299]{0} add(f32[299]{0} %multiply.100, f32[299]{0} %multiply.98), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.97 = f32[299]{0} multiply(f32[299]{0} %broadcast.530, f32[299]{0} %subtract.71), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %constant.251..sunk.4 = pred[299]{0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %and.85 = 
pred[299]{0} and(pred[299]{0} %constant.251..sunk.4, pred[299]{0} %and.89), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %param_3.34 = f32[299]{0} parameter(3) %select.363 = f32[299]{0} select(pred[299]{0} %and.85, f32[299]{0} %param_3.34, f32[299]{0} %broadcast.530), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.96 = f32[299]{0} multiply(f32[299]{0} %multiply.97, f32[299]{0} %select.363), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.299 = f32[299]{0} add(f32[299]{0} %add.300, f32[299]{0} %multiply.96), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.95 = f32[299]{0} multiply(f32[299]{0} %broadcast.530, f32[299]{0} %subtract.72), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %and.84 = pred[299]{0} and(pred[299]{0} %constant.251..sunk.4, pred[299]{0} %and.87), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %param_2.22 = f32[299]{0} parameter(2) %select.362 = f32[299]{0} select(pred[299]{0} %and.84, f32[299]{0} %param_2.22, f32[299]{0} %broadcast.530), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.94 = f32[299]{0} multiply(f32[299]{0} %multiply.95, f32[299]{0} %select.362), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.298 = f32[299]{0} add(f32[299]{0} %add.299, f32[299]{0} %multiply.94), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.93 = f32[299]{0} multiply(f32[299]{0} %convert.57, f32[299]{0} %add.298), 
metadata={op_name="jit(main1)/jit(main)/while/body/mul" source_file="" source_line=66} %add.297 = f32[299]{0} add(f32[299]{0} %select.369, f32[299]{0} %multiply.93), metadata={op_name="jit(main1)/jit(main)/while/body/add" source_file="" source_line=65} %constant.503 = f32[1]{0} constant({2.90232038e-06}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(append)/reshape" source_file="" source_line=67} %concatenate.49 = f32[300]{0} concatenate(f32[299]{0} %add.297, f32[1]{0} %constant.503), dimensions={0}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(append)/concatenate" source_file="" source_line=67} %bitcast.17 = f32[300,1]{1,0} bitcast(f32[300]{0} %concatenate.49), metadata={op_name="jit(main1)/jit(main)/while/body/jit(append)/concatenate" source_file="" source_line=67} %dynamic-slice.23 = f32[300,1]{1,0} dynamic-slice(f32[300,100000]{1,0} %param_0.6, s32[] %constant.508, s32[] %param_1.17), dynamic_slice_sizes={300,1} %select.361 = f32[300,1]{1,0} select(pred[300,1]{1,0} %broadcast.533, f32[300,1]{1,0} %bitcast.17, f32[300,1]{1,0} %dynamic-slice.23) ROOT %dynamic-update-slice.8 = f32[300,100000]{1,0} dynamic-update-slice(f32[300,100000]{1,0} %param_0.6, f32[300,1]{1,0} %select.361, s32[] %constant.508, s32[] %param_1.17), metadata={op_name="jit(main1)/jit(main)/while/body/scatter" source_file="" source_line=73} } %fused_computation.8 (param_0.11: f32[100000], param_1.62: s32[299]) -> f32[299] { %param_0.11 = f32[100000]{0} parameter(0) %constant.529 = s32[] constant(1) %broadcast.571 = s32[299]{0} broadcast(s32[] %constant.529), dimensions={} %param_1.62 = s32[299]{0} parameter(1) %constant.526 = s32[] constant(99999) %broadcast.568 = s32[299]{0} broadcast(s32[] %constant.526), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.7 = s32[299]{0} clamp(s32[299]{0} %broadcast.571, s32[299]{0} %param_1.62, s32[299]{0} %broadcast.568), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.528 = s32[] constant(0) %broadcast.570 = s32[299]{0} broadcast(s32[] %constant.528), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.392 = pred[299]{0} compare(s32[299]{0} %clamp.7, s32[299]{0} %broadcast.570), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/lt" source_file="" source_line=58} %constant.527 = s32[] constant(100000) %broadcast.569 = s32[299]{0} broadcast(s32[] %constant.527), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %add.321 = s32[299]{0} add(s32[299]{0} %clamp.7, s32[299]{0} %broadcast.569), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.386 = s32[299]{0} select(pred[299]{0} %compare.392, s32[299]{0} %add.321, s32[299]{0} %clamp.7), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.47 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.386), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.100 = f32[299]{0} gather(f32[100000]{0} %param_0.11, s32[299,1]{1,0} %bitcast.47), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.9 (param_0.16: f32[299], param_1.69: f32[299], param_2.80: f32[299], param_3.99: f32[100000], param_4.88: s32[]) -> f32[299] { %param_3.99 = f32[100000]{0} parameter(3) %param_4.88 = s32[] parameter(4) %dynamic-slice.29 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_3.99, s32[] %param_4.88), dynamic_slice_sizes={1}, 
metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.48 = f32[] bitcast(f32[1]{0} %dynamic-slice.29), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.574 = f32[299]{0} broadcast(f32[] %bitcast.48), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_2.80 = f32[299]{0} parameter(2) %add.322 = f32[299]{0} add(f32[299]{0} %broadcast.574, f32[299]{0} %param_2.80), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_0.16 = f32[299]{0} parameter(0) %subtract.84 = f32[299]{0} subtract(f32[299]{0} %add.322, f32[299]{0} %param_0.16), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %param_1.69 = f32[299]{0} parameter(1) %subtract.85 = f32[299]{0} subtract(f32[299]{0} %param_1.69, f32[299]{0} %param_0.16), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.13 = f32[299]{0} abs(f32[299]{0} %subtract.85), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.530 = f32[] constant(1.42108547e-14) %broadcast.572 = f32[299]{0} broadcast(f32[] %constant.530), dimensions={} %compare.393 = pred[299]{0} compare(f32[299]{0} %abs.13, f32[299]{0} %broadcast.572), direction=LE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %constant.531 = f32[] constant(1) %broadcast.573 = f32[299]{0} broadcast(f32[] %constant.531), dimensions={} %select.387 = f32[299]{0} select(pred[299]{0} %compare.393, f32[299]{0} %broadcast.573, f32[299]{0} %subtract.85), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} ROOT %divide.14 = f32[299]{0} divide(f32[299]{0} %subtract.84, f32[299]{0} %select.387), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/div" source_file="" source_line=58} } %fused_computation.10 (param_0.17: f32[100000], param_1.78: s32[299]) -> f32[299] { %param_0.17 = f32[100000]{0} parameter(0) %constant.534 = s32[] constant(1) %broadcast.578 = s32[299]{0} broadcast(s32[] %constant.534), dimensions={} %param_1.78 = s32[299]{0} parameter(1) %constant.532 = s32[] constant(99999) %broadcast.576 = s32[299]{0} broadcast(s32[] %constant.532), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.8 = s32[299]{0} clamp(s32[299]{0} %broadcast.578, s32[299]{0} %param_1.78, s32[299]{0} %broadcast.576), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.232.clone.9 = s32[] constant(-1) %broadcast.575 = s32[299]{0} broadcast(s32[] %constant.232.clone.9), dimensions={} %add.324 = s32[299]{0} add(s32[299]{0} %clamp.8, s32[299]{0} %broadcast.575), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %constant.533 = s32[] constant(0) %broadcast.577 = s32[299]{0} broadcast(s32[] %constant.533), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.394 = pred[299]{0} compare(s32[299]{0} %add.324, s32[299]{0} %broadcast.577), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/lt" source_file="" source_line=58} %add.323 = s32[299]{0} add(s32[299]{0} %clamp.8, s32[299]{0} %broadcast.576), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.388 = s32[299]{0} select(pred[299]{0} %compare.394, s32[299]{0} %add.323, s32[299]{0} %add.324), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.49 = s32[299,1]{1,0} 
bitcast(s32[299]{0} %select.388), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.101 = f32[299]{0} gather(f32[100000]{0} %param_0.17, s32[299,1]{1,0} %bitcast.49), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.11 (param_0.18: f32[100000], param_1.87: s32[299]) -> f32[299] { %param_0.18 = f32[100000]{0} parameter(0) %constant.537 = s32[] constant(1) %broadcast.582 = s32[299]{0} broadcast(s32[] %constant.537), dimensions={} %param_1.87 = s32[299]{0} parameter(1) %constant.535 = s32[] constant(99999) %broadcast.580 = s32[299]{0} broadcast(s32[] %constant.535), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.9 = s32[299]{0} clamp(s32[299]{0} %broadcast.582, s32[299]{0} %param_1.87, s32[299]{0} %broadcast.580), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.232.clone.10 = s32[] constant(-1) %broadcast.579 = s32[299]{0} broadcast(s32[] %constant.232.clone.10), dimensions={} %add.327 = s32[299]{0} add(s32[299]{0} %clamp.9, s32[299]{0} %broadcast.579), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %constant.536 = s32[] constant(0) %broadcast.581 = s32[299]{0} broadcast(s32[] %constant.536), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.395 = pred[299]{0} compare(s32[299]{0} %add.327, s32[299]{0} %broadcast.581), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/lt" source_file="" source_line=58} %add.326 = s32[299]{0} add(s32[299]{0} %clamp.9, s32[299]{0} %broadcast.580), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.389 = s32[299]{0} select(pred[299]{0} %compare.395, s32[299]{0} %add.326, s32[299]{0} %add.327), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.50 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.389), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.102 = f32[299]{0} gather(f32[100000]{0} %param_0.18, s32[299,1]{1,0} %bitcast.50), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.12 (param_0.19: f32[100000], param_1.95: s32[299]) -> f32[299] { %param_0.19 = f32[100000]{0} parameter(0) %constant.541 = s32[] constant(1) %broadcast.586 = s32[299]{0} broadcast(s32[] %constant.541), dimensions={} %param_1.95 = s32[299]{0} parameter(1) %constant.538 = s32[] constant(99999) %broadcast.583 = s32[299]{0} broadcast(s32[] %constant.538), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.10 = s32[299]{0} clamp(s32[299]{0} %broadcast.586, s32[299]{0} %param_1.95, s32[299]{0} %broadcast.583), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.540 = s32[] constant(0) %broadcast.585 = s32[299]{0} broadcast(s32[] %constant.540), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.396 = pred[299]{0} compare(s32[299]{0} %clamp.10, s32[299]{0} %broadcast.585), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/lt" source_file="" source_line=58} %constant.539 = s32[] constant(100000) %broadcast.584 = 
s32[299]{0} broadcast(s32[] %constant.539), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %add.329 = s32[299]{0} add(s32[299]{0} %clamp.10, s32[299]{0} %broadcast.584), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.390 = s32[299]{0} select(pred[299]{0} %compare.396, s32[299]{0} %add.329, s32[299]{0} %clamp.10), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.51 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.390), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.103 = f32[299]{0} gather(f32[100000]{0} %param_0.19, s32[299,1]{1,0} %bitcast.51), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.13 (param_0.26: f32[299], param_1.106: f32[100000], param_2.120: s32[]) -> s32[299] { %param_1.106 = f32[100000]{0} parameter(1) %param_2.120 = s32[] parameter(2) %dynamic-slice.30 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_1.106, s32[] %param_2.120), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.52 = f32[] bitcast(f32[1]{0} %dynamic-slice.30), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.591 = f32[299]{0} broadcast(f32[] %bitcast.52), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_0.26 = f32[299]{0} parameter(0) %add.330 = f32[299]{0} add(f32[299]{0} %broadcast.591, f32[299]{0} %param_0.26), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %compare.399 = pred[299]{0} 
compare(f32[299]{0} %add.330, f32[299]{0} %add.330), direction=NE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.543 = f32[] constant(nan) %broadcast.588 = f32[299]{0} broadcast(f32[] %constant.543), dimensions={} %constant.544 = f32[] constant(0) %broadcast.590 = f32[299]{0} broadcast(f32[] %constant.544), dimensions={} %compare.398 = pred[299]{0} compare(f32[299]{0} %add.330, f32[299]{0} %broadcast.590), direction=EQ, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.394 = f32[299]{0} select(pred[299]{0} %compare.398, f32[299]{0} %broadcast.590, f32[299]{0} %add.330), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.393 = f32[299]{0} select(pred[299]{0} %compare.399, f32[299]{0} %broadcast.588, f32[299]{0} %select.394), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.19 = s32[299]{0} bitcast-convert(f32[299]{0} %select.393) %constant.545 = s32[] constant(0) %broadcast.589 = s32[299]{0} broadcast(s32[] %constant.545), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.397 = pred[299]{0} compare(s32[299]{0} %bitcast-convert.19, s32[299]{0} %broadcast.589), direction=LT %constant.542 = s32[] constant(2147483647) %broadcast.587 = s32[299]{0} broadcast(s32[] %constant.542), dimensions={} %xor.19 = s32[299]{0} xor(s32[299]{0} %broadcast.587, s32[299]{0} %bitcast-convert.19) ROOT %select.392 = s32[299]{0} select(pred[299]{0} %compare.397, s32[299]{0} %xor.19, s32[299]{0} %bitcast-convert.19) } %and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] { %lhs = pred[] 
parameter(0) %rhs = pred[] parameter(1) ROOT %and.18 = pred[] and(pred[] %lhs, pred[] %rhs) } %fused_computation.14 (param_0.30: s32[2], param_1.111: s32[2]) -> pred[] { %constant.547 = s32[] constant(0) %broadcast.592 = s32[2]{0} broadcast(s32[] %constant.547), dimensions={} %param_1.111 = s32[2]{0} parameter(1) %compare.401 = pred[2]{0} compare(s32[2]{0} %broadcast.592, s32[2]{0} %param_1.111), direction=LE %param_0.30 = s32[2]{0} parameter(0) %compare.400 = pred[2]{0} compare(s32[2]{0} %param_0.30, s32[2]{0} %param_1.111), direction=GE %and.90 = pred[2]{0} and(pred[2]{0} %compare.401, pred[2]{0} %compare.400) %constant.546 = pred[] constant(true) ROOT %reduce.7 = pred[] reduce(pred[2]{0} %and.90, pred[] %constant.546), dimensions={0}, to_apply=%and.reduce_sub_computation } %fused_computation.15 (param_0.32: s32[]) -> s32[2] { %constant.548 = s32[1]{0} constant({0}) %param_0.32 = s32[] parameter(0) %bitcast.53 = s32[1]{0} bitcast(s32[] %param_0.32), metadata={op_name="jit(main1)/jit(main)/while/body/select_n" source_file="" source_line=73} ROOT %concatenate.54 = s32[2]{0} concatenate(s32[1]{0} %constant.548, s32[1]{0} %bitcast.53), dimensions={0} } %fused_computation.16 (param_0.37: s32[100000], param_1.119: s32[]) -> s32[] { %param_0.37 = s32[100000]{0} parameter(0) %param_1.119 = s32[] parameter(1) %constant.550 = s32[] constant(0) %compare.403 = pred[] compare(s32[] %param_1.119, s32[] %constant.550), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=26} %constant.549 = s32[] constant(100000) %add.332 = s32[] add(s32[] %param_1.119, s32[] %constant.549), metadata={op_name="jit(main1)/jit(main)/while/body/add" source_file="" source_line=26} %select.396 = s32[] select(pred[] %compare.403, s32[] %add.332, s32[] %param_1.119), metadata={op_name="jit(main1)/jit(main)/while/body/select_n" source_file="" source_line=26} %dynamic-slice.31 = s32[1]{0} dynamic-slice(s32[100000]{0} %param_0.37, s32[] %select.396), 
dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=26} %bitcast.54 = s32[] bitcast(s32[1]{0} %dynamic-slice.31), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=26} %compare.402 = pred[] compare(s32[] %bitcast.54, s32[] %constant.550), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=73} %add.331 = s32[] add(s32[] %bitcast.54, s32[] %constant.549), metadata={op_name="jit(main1)/jit(main)/while/body/add" source_file="" source_line=73} ROOT %select.395 = s32[] select(pred[] %compare.402, s32[] %add.331, s32[] %bitcast.54), metadata={op_name="jit(main1)/jit(main)/while/body/select_n" source_file="" source_line=73} } %fused_computation.4.clone (param_0.46: f32[300,100000], param_1.128: f32[299], param_2.136: f32[299], param_3.134: f32[299], param_4.110: f32[299], param_5.84: f32[299], param_6.77: f32[100000], param_7.68: f32[100000], param_8.67: f32[299], param_9.66: s32[]) -> f32[299] { %param_0.46 = f32[300,100000]{1,0} parameter(0) %constant.556 = s32[299,1]{1,0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/broadcast_in_dim" source_file="" source_line=60} %param_7.68 = f32[100000]{0} parameter(7) %param_9.66 = s32[] parameter(9) %dynamic-slice.32 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_7.68, s32[] %param_9.66), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.60 = f32[] bitcast(f32[1]{0} %dynamic-slice.32), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.605 = f32[299]{0} broadcast(f32[] %bitcast.60), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_8.67 = f32[299]{0} parameter(8) %add.337 = f32[299]{0} add(f32[299]{0} %broadcast.605, f32[299]{0} %param_8.67), 
metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %slice.81 = f32[1]{0} slice(f32[100000]{0} %param_7.68), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.58 = f32[] bitcast(f32[1]{0} %slice.81), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.602 = f32[299]{0} broadcast(f32[] %bitcast.58), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %compare.407 = pred[299]{0} compare(f32[299]{0} %add.337, f32[299]{0} %broadcast.602), direction=GT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %param_6.77 = f32[100000]{0} parameter(6) %slice.80 = f32[1]{0} slice(f32[100000]{0} %param_6.77), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.57 = f32[] bitcast(f32[1]{0} %slice.80), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.601 = f32[299]{0} broadcast(f32[] %bitcast.57), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.82 = f32[1]{0} slice(f32[100000]{0} %param_7.68), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %bitcast.59 = f32[] bitcast(f32[1]{0} %slice.82), metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %broadcast.604 = f32[299]{0} broadcast(f32[] %bitcast.59), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %compare.408 = pred[299]{0} compare(f32[299]{0} %add.337, f32[299]{0} %broadcast.604), direction=LT, 
metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %slice.79 = f32[1]{0} slice(f32[100000]{0} %param_6.77), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %bitcast.56 = f32[] bitcast(f32[1]{0} %slice.79), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %broadcast.600 = f32[299]{0} broadcast(f32[] %bitcast.56), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_4.110 = f32[299]{0} parameter(4) %param_5.84 = f32[299]{0} parameter(5) %subtract.88 = f32[299]{0} subtract(f32[299]{0} %param_4.110, f32[299]{0} %param_5.84), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.14 = f32[299]{0} abs(f32[299]{0} %subtract.88), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.557 = f32[] constant(1.42108547e-14) %broadcast.598 = f32[299]{0} broadcast(f32[] %constant.557), dimensions={} %compare.405 = pred[299]{0} compare(f32[299]{0} %abs.14, f32[299]{0} %broadcast.598), direction=LE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %param_1.128 = f32[299]{0} parameter(1) %param_2.136 = f32[299]{0} parameter(2) %param_3.134 = f32[299]{0} parameter(3) %subtract.87 = f32[299]{0} subtract(f32[299]{0} %param_3.134, f32[299]{0} %param_1.128), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %multiply.110 = f32[299]{0} multiply(f32[299]{0} %param_2.136, f32[299]{0} %subtract.87), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/mul" source_file="" source_line=58} %add.336 = f32[299]{0} add(f32[299]{0} %param_1.128, f32[299]{0} %multiply.110), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.400 = f32[299]{0} select(pred[299]{0} %compare.405, f32[299]{0} %param_1.128, f32[299]{0} %add.336), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.399 = f32[299]{0} select(pred[299]{0} %compare.408, f32[299]{0} %broadcast.600, f32[299]{0} %select.400), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.398 = f32[299]{0} select(pred[299]{0} %compare.407, f32[299]{0} %broadcast.601, f32[299]{0} %select.399), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.13 = f32[299]{0} floor(f32[299]{0} %select.398), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/floor" source_file="" source_line=60} %convert.64 = s32[299]{0} convert(f32[299]{0} %floor.13), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %constant.559 = s32[] constant(1) %broadcast.603 = s32[299]{0} broadcast(s32[] %constant.559), dimensions={} %add.335 = s32[299]{0} add(s32[299]{0} %convert.64, s32[299]{0} %broadcast.603), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %constant.558 = s32[] constant(0) %broadcast.599 = s32[299]{0} broadcast(s32[] %constant.558), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.404 = pred[299]{0} compare(s32[299]{0} %add.335, s32[299]{0} %broadcast.599), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.247.clone.10 = s32[] constant(100001) %broadcast.597 = s32[299]{0} broadcast(s32[] 
%constant.247.clone.10), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %add.334 = s32[299]{0} add(s32[299]{0} %convert.64, s32[299]{0} %broadcast.597), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %select.397 = s32[299]{0} select(pred[299]{0} %compare.404, s32[299]{0} %add.334, s32[299]{0} %add.335), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.55 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.397), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.56 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %constant.556, s32[299,1]{1,0} %bitcast.55), dimensions={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/concatenate" source_file="" source_line=60} ROOT %gather.104 = f32[299]{0} gather(f32[300,100000]{1,0} %param_0.46, s32[299,2]{1,0} %concatenate.56), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60} } %parallel_fusion.4 (p.1: f32[300,100000], p.2: f32[299], p.3: f32[299], p.4: f32[299], p.5: f32[299], p.6: f32[299], p.7: f32[100000], p.8: f32[100000], p.9: f32[299], p.10: s32[]) -> f32[299] { %p.1 = f32[300,100000]{1,0} parameter(0) %p.10 = s32[] parameter(9) %p.2 = f32[299]{0} parameter(1) %p.3 = f32[299]{0} parameter(2) %p.4 = f32[299]{0} parameter(3) %p.5 = f32[299]{0} parameter(4) %p.6 = f32[299]{0} parameter(5) %p.7 = f32[100000]{0} parameter(6) %p.8 = f32[100000]{0} parameter(7) %p.9 = f32[299]{0} parameter(8) ROOT %fusion.4.clone = f32[299]{0} fusion(f32[300,100000]{1,0} %p.1, f32[299]{0} %p.2, f32[299]{0} %p.3, f32[299]{0} %p.4, f32[299]{0} %p.5, /*index=5*/f32[299]{0} %p.6, 
f32[100000]{0} %p.7, f32[100000]{0} %p.8, f32[299]{0} %p.9, s32[] %p.10), kind=kLoop, calls=%fused_computation.4.clone, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60}, backend_config={"outer_dimension_partitions":["2"]} } %fused_computation.5.clone (param_0.47: f32[300,100000], param_1.129: f32[299], param_2.137: f32[299], param_3.135: f32[299], param_4.111: f32[299], param_5.85: f32[299], param_6.78: f32[100000], param_7.69: f32[100000], param_8.68: f32[299], param_9.67: s32[]) -> f32[299] { %param_0.47 = f32[300,100000]{1,0} parameter(0) %constant.560 = s32[299,1]{1,0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/broadcast_in_dim" source_file="" source_line=60} %param_7.69 = f32[100000]{0} parameter(7) %param_9.67 = s32[] parameter(9) %dynamic-slice.33 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_7.69, s32[] %param_9.67), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.66 = f32[] bitcast(f32[1]{0} %dynamic-slice.33), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.613 = f32[299]{0} broadcast(f32[] %bitcast.66), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_8.68 = f32[299]{0} parameter(8) %add.341 = f32[299]{0} add(f32[299]{0} %broadcast.613, f32[299]{0} %param_8.68), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %slice.85 = f32[1]{0} slice(f32[100000]{0} %param_7.69), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.64 = f32[] bitcast(f32[1]{0} %slice.85), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.611 = f32[299]{0} broadcast(f32[] 
%bitcast.64), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %compare.411 = pred[299]{0} compare(f32[299]{0} %add.341, f32[299]{0} %broadcast.611), direction=GT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %param_6.78 = f32[100000]{0} parameter(6) %slice.84 = f32[1]{0} slice(f32[100000]{0} %param_6.78), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.63 = f32[] bitcast(f32[1]{0} %slice.84), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.610 = f32[299]{0} broadcast(f32[] %bitcast.63), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.86 = f32[1]{0} slice(f32[100000]{0} %param_7.69), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %bitcast.65 = f32[] bitcast(f32[1]{0} %slice.86), metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %broadcast.612 = f32[299]{0} broadcast(f32[] %bitcast.65), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %compare.412 = pred[299]{0} compare(f32[299]{0} %add.341, f32[299]{0} %broadcast.612), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %slice.83 = f32[1]{0} slice(f32[100000]{0} %param_6.78), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %bitcast.62 = f32[] bitcast(f32[1]{0} %slice.83), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %broadcast.609 = f32[299]{0} broadcast(f32[] %bitcast.62), dimensions={}, 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_4.111 = f32[299]{0} parameter(4) %param_5.85 = f32[299]{0} parameter(5) %subtract.90 = f32[299]{0} subtract(f32[299]{0} %param_4.111, f32[299]{0} %param_5.85), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.15 = f32[299]{0} abs(f32[299]{0} %subtract.90), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.561 = f32[] constant(1.42108547e-14) %broadcast.606 = f32[299]{0} broadcast(f32[] %constant.561), dimensions={} %compare.410 = pred[299]{0} compare(f32[299]{0} %abs.15, f32[299]{0} %broadcast.606), direction=LE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %param_1.129 = f32[299]{0} parameter(1) %param_2.137 = f32[299]{0} parameter(2) %param_3.135 = f32[299]{0} parameter(3) %subtract.89 = f32[299]{0} subtract(f32[299]{0} %param_3.135, f32[299]{0} %param_1.129), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %multiply.111 = f32[299]{0} multiply(f32[299]{0} %param_2.137, f32[299]{0} %subtract.89), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/mul" source_file="" source_line=58} %add.339 = f32[299]{0} add(f32[299]{0} %param_1.129, f32[299]{0} %multiply.111), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.404 = f32[299]{0} select(pred[299]{0} %compare.410, f32[299]{0} %param_1.129, f32[299]{0} %add.339), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.403 = f32[299]{0} select(pred[299]{0} %compare.412, f32[299]{0} %broadcast.609, f32[299]{0} %select.404), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" 
source_line=58} %select.402 = f32[299]{0} select(pred[299]{0} %compare.411, f32[299]{0} %broadcast.610, f32[299]{0} %select.403), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.14 = f32[299]{0} floor(f32[299]{0} %select.402), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/floor" source_file="" source_line=60} %convert.65 = s32[299]{0} convert(f32[299]{0} %floor.14), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %constant.563 = s32[] constant(0) %broadcast.608 = s32[299]{0} broadcast(s32[] %constant.563), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.409 = pred[299]{0} compare(s32[299]{0} %convert.65, s32[299]{0} %broadcast.608), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.562 = s32[] constant(100000) %broadcast.607 = s32[299]{0} broadcast(s32[] %constant.562), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %add.338 = s32[299]{0} add(s32[299]{0} %convert.65, s32[299]{0} %broadcast.607), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %select.401 = s32[299]{0} select(pred[299]{0} %compare.409, s32[299]{0} %add.338, s32[299]{0} %convert.65), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.61 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.401), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.57 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %constant.560, s32[299,1]{1,0} %bitcast.61), 
dimensions={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/concatenate" source_file="" source_line=60} ROOT %gather.105 = f32[299]{0} gather(f32[300,100000]{1,0} %param_0.47, s32[299,2]{1,0} %concatenate.57), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60} } %parallel_fusion.5 (p.11: f32[300,100000], p.12: f32[299], p.13: f32[299], p.14: f32[299], p.15: f32[299], p.16: f32[299], p.17: f32[100000], p.18: f32[100000], p.19: f32[299], p.20: s32[]) -> f32[299] { %p.11 = f32[300,100000]{1,0} parameter(0) %p.12 = f32[299]{0} parameter(1) %p.13 = f32[299]{0} parameter(2) %p.14 = f32[299]{0} parameter(3) %p.15 = f32[299]{0} parameter(4) %p.16 = f32[299]{0} parameter(5) %p.17 = f32[100000]{0} parameter(6) %p.18 = f32[100000]{0} parameter(7) %p.19 = f32[299]{0} parameter(8) %p.20 = s32[] parameter(9) ROOT %fusion.5.clone = f32[299]{0} fusion(f32[300,100000]{1,0} %p.11, f32[299]{0} %p.12, f32[299]{0} %p.13, f32[299]{0} %p.14, f32[299]{0} %p.15, /*index=5*/f32[299]{0} %p.16, f32[100000]{0} %p.17, f32[100000]{0} %p.18, f32[299]{0} %p.19, s32[] %p.20), kind=kLoop, calls=%fused_computation.5.clone, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60}, backend_config={"outer_dimension_partitions":["2"]} } %fused_computation.6.clone (param_0.48: f32[300,100000], param_1.130: f32[299], param_2.138: f32[299], param_3.136: f32[299], param_4.112: f32[299], param_5.86: f32[299], param_6.79: f32[100000], param_7.70: f32[100000], param_8.69: f32[299], param_9.68: s32[]) -> f32[299] { %param_0.48 = f32[300,100000]{1,0} parameter(0) %constant.564 = s32[299,1]{1,0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/broadcast_in_dim" source_file="" source_line=60} %param_7.70 = f32[100000]{0} 
parameter(7) %param_9.68 = s32[] parameter(9) %dynamic-slice.34 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_7.70, s32[] %param_9.68), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.72 = f32[] bitcast(f32[1]{0} %dynamic-slice.34), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.622 = f32[299]{0} broadcast(f32[] %bitcast.72), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_8.69 = f32[299]{0} parameter(8) %add.346 = f32[299]{0} add(f32[299]{0} %broadcast.622, f32[299]{0} %param_8.69), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %slice.89 = f32[1]{0} slice(f32[100000]{0} %param_7.70), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.70 = f32[] bitcast(f32[1]{0} %slice.89), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.619 = f32[299]{0} broadcast(f32[] %bitcast.70), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %compare.415 = pred[299]{0} compare(f32[299]{0} %add.346, f32[299]{0} %broadcast.619), direction=GT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %param_6.79 = f32[100000]{0} parameter(6) %slice.88 = f32[1]{0} slice(f32[100000]{0} %param_6.79), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.69 = f32[] bitcast(f32[1]{0} %slice.88), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.618 = f32[299]{0} broadcast(f32[] %bitcast.69), dimensions={}, 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.90 = f32[1]{0} slice(f32[100000]{0} %param_7.70), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %bitcast.71 = f32[] bitcast(f32[1]{0} %slice.90), metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %broadcast.621 = f32[299]{0} broadcast(f32[] %bitcast.71), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %compare.416 = pred[299]{0} compare(f32[299]{0} %add.346, f32[299]{0} %broadcast.621), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %slice.87 = f32[1]{0} slice(f32[100000]{0} %param_6.79), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %bitcast.68 = f32[] bitcast(f32[1]{0} %slice.87), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %broadcast.617 = f32[299]{0} broadcast(f32[] %bitcast.68), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_4.112 = f32[299]{0} parameter(4) %param_5.86 = f32[299]{0} parameter(5) %subtract.92 = f32[299]{0} subtract(f32[299]{0} %param_4.112, f32[299]{0} %param_5.86), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.16 = f32[299]{0} abs(f32[299]{0} %subtract.92), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.565 = f32[] constant(1.42108547e-14) %broadcast.615 = f32[299]{0} broadcast(f32[] %constant.565), dimensions={} %compare.414 = pred[299]{0} compare(f32[299]{0} %abs.16, f32[299]{0} %broadcast.615), direction=LE, 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %param_1.130 = f32[299]{0} parameter(1) %param_2.138 = f32[299]{0} parameter(2) %param_3.136 = f32[299]{0} parameter(3) %subtract.91 = f32[299]{0} subtract(f32[299]{0} %param_3.136, f32[299]{0} %param_1.130), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %multiply.112 = f32[299]{0} multiply(f32[299]{0} %param_2.138, f32[299]{0} %subtract.91), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/mul" source_file="" source_line=58} %add.345 = f32[299]{0} add(f32[299]{0} %param_1.130, f32[299]{0} %multiply.112), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.408 = f32[299]{0} select(pred[299]{0} %compare.414, f32[299]{0} %param_1.130, f32[299]{0} %add.345), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.407 = f32[299]{0} select(pred[299]{0} %compare.416, f32[299]{0} %broadcast.617, f32[299]{0} %select.408), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.406 = f32[299]{0} select(pred[299]{0} %compare.415, f32[299]{0} %broadcast.618, f32[299]{0} %select.407), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.15 = f32[299]{0} floor(f32[299]{0} %select.406), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/floor" source_file="" source_line=60} %convert.66 = s32[299]{0} convert(f32[299]{0} %floor.15), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %constant.567 = s32[] constant(1) %broadcast.620 = s32[299]{0} broadcast(s32[] %constant.567), dimensions={} %add.344 = s32[299]{0} add(s32[299]{0} %convert.66, s32[299]{0} 
%broadcast.620), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %constant.566 = s32[] constant(0) %broadcast.616 = s32[299]{0} broadcast(s32[] %constant.566), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.413 = pred[299]{0} compare(s32[299]{0} %add.344, s32[299]{0} %broadcast.616), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.247.clone.11 = s32[] constant(100001) %broadcast.614 = s32[299]{0} broadcast(s32[] %constant.247.clone.11), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %add.342 = s32[299]{0} add(s32[299]{0} %convert.66, s32[299]{0} %broadcast.614), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %select.405 = s32[299]{0} select(pred[299]{0} %compare.413, s32[299]{0} %add.342, s32[299]{0} %add.344), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.67 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.405), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.58 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %constant.564, s32[299,1]{1,0} %bitcast.67), dimensions={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/concatenate" source_file="" source_line=60} ROOT %gather.106 = f32[299]{0} gather(f32[300,100000]{1,0} %param_0.48, s32[299,2]{1,0} %concatenate.58), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60} } %parallel_fusion.6 (p.21: 
f32[300,100000], p.22: f32[299], p.23: f32[299], p.24: f32[299], p.25: f32[299], p.26: f32[299], p.27: f32[100000], p.28: f32[100000], p.29: f32[299], p.30: s32[]) -> f32[299] { %p.21 = f32[300,100000]{1,0} parameter(0) %p.22 = f32[299]{0} parameter(1) %p.23 = f32[299]{0} parameter(2) %p.24 = f32[299]{0} parameter(3) %p.25 = f32[299]{0} parameter(4) %p.26 = f32[299]{0} parameter(5) %p.27 = f32[100000]{0} parameter(6) %p.28 = f32[100000]{0} parameter(7) %p.29 = f32[299]{0} parameter(8) %p.30 = s32[] parameter(9) ROOT %fusion.6.clone = f32[299]{0} fusion(f32[300,100000]{1,0} %p.21, f32[299]{0} %p.22, f32[299]{0} %p.23, f32[299]{0} %p.24, f32[299]{0} %p.25, /*index=5*/f32[299]{0} %p.26, f32[100000]{0} %p.27, f32[100000]{0} %p.28, f32[299]{0} %p.29, s32[] %p.30), kind=kLoop, calls=%fused_computation.6.clone, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60}, backend_config={"outer_dimension_partitions":["2"]} } %fused_computation.7.clone (param_0.49: f32[300,100000], param_1.131: f32[299], param_2.139: f32[299], param_3.137: f32[299], param_4.113: f32[299], param_5.87: f32[299], param_6.80: f32[100000], param_7.71: f32[100000], param_8.70: f32[299], param_9.69: s32[]) -> f32[299] { %param_0.49 = f32[300,100000]{1,0} parameter(0) %constant.568 = s32[299,1]{1,0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/broadcast_in_dim" source_file="" source_line=60} %param_7.71 = f32[100000]{0} parameter(7) %param_9.69 = s32[] parameter(9) %dynamic-slice.35 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_7.71, s32[] %param_9.69), dynamic_slice_sizes={1}, metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %bitcast.78 = f32[] bitcast(f32[1]{0} %dynamic-slice.35), metadata={op_name="jit(main1)/jit(main)/while/body/dynamic_slice" source_file="" source_line=19} %broadcast.630 = f32[299]{0} broadcast(f32[] %bitcast.78), dimensions={}, 
metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %param_8.70 = f32[299]{0} parameter(8) %add.349 = f32[299]{0} add(f32[299]{0} %broadcast.630, f32[299]{0} %param_8.70), metadata={op_name="jit(main1)/jit(main)/while/body/sub" source_file="" source_line=56} %slice.93 = f32[1]{0} slice(f32[100000]{0} %param_7.71), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.76 = f32[] bitcast(f32[1]{0} %slice.93), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.628 = f32[299]{0} broadcast(f32[] %bitcast.76), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %compare.419 = pred[299]{0} compare(f32[299]{0} %add.349, f32[299]{0} %broadcast.628), direction=GT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gt" source_file="" source_line=58} %param_6.80 = f32[100000]{0} parameter(6) %slice.92 = f32[1]{0} slice(f32[100000]{0} %param_6.80), slice={[99999:100000]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.75 = f32[] bitcast(f32[1]{0} %slice.92), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.627 = f32[299]{0} broadcast(f32[] %bitcast.75), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.94 = f32[1]{0} slice(f32[100000]{0} %param_7.71), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %bitcast.77 = f32[] bitcast(f32[1]{0} %slice.94), metadata={op_name="jit(main1)/jit(main)/while/body/slice" source_file="" source_line=62} %broadcast.629 = f32[299]{0} broadcast(f32[] %bitcast.77), dimensions={}, 
metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %compare.420 = pred[299]{0} compare(f32[299]{0} %add.349, f32[299]{0} %broadcast.629), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/lt" source_file="" source_line=62} %slice.91 = f32[1]{0} slice(f32[100000]{0} %param_6.80), slice={[0:1]}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %bitcast.74 = f32[] bitcast(f32[1]{0} %slice.91), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/slice" source_file="" source_line=58} %broadcast.626 = f32[299]{0} broadcast(f32[] %bitcast.74), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_4.113 = f32[299]{0} parameter(4) %param_5.87 = f32[299]{0} parameter(5) %subtract.94 = f32[299]{0} subtract(f32[299]{0} %param_4.113, f32[299]{0} %param_5.87), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %abs.17 = f32[299]{0} abs(f32[299]{0} %subtract.94), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/abs" source_file="" source_line=58} %constant.569 = f32[] constant(1.42108547e-14) %broadcast.623 = f32[299]{0} broadcast(f32[] %constant.569), dimensions={} %compare.418 = pred[299]{0} compare(f32[299]{0} %abs.17, f32[299]{0} %broadcast.623), direction=LE, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/le" source_file="" source_line=58} %param_1.131 = f32[299]{0} parameter(1) %param_2.139 = f32[299]{0} parameter(2) %param_3.137 = f32[299]{0} parameter(3) %subtract.93 = f32[299]{0} subtract(f32[299]{0} %param_3.137, f32[299]{0} %param_1.131), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/sub" source_file="" source_line=58} %multiply.113 = f32[299]{0} multiply(f32[299]{0} %param_2.139, f32[299]{0} %subtract.93), 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/mul" source_file="" source_line=58} %add.348 = f32[299]{0} add(f32[299]{0} %param_1.131, f32[299]{0} %multiply.113), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/add" source_file="" source_line=58} %select.412 = f32[299]{0} select(pred[299]{0} %compare.418, f32[299]{0} %param_1.131, f32[299]{0} %add.348), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.411 = f32[299]{0} select(pred[299]{0} %compare.420, f32[299]{0} %broadcast.626, f32[299]{0} %select.412), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.410 = f32[299]{0} select(pred[299]{0} %compare.419, f32[299]{0} %broadcast.627, f32[299]{0} %select.411), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.16 = f32[299]{0} floor(f32[299]{0} %select.410), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/floor" source_file="" source_line=60} %convert.67 = s32[299]{0} convert(f32[299]{0} %floor.16), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %constant.571 = s32[] constant(0) %broadcast.625 = s32[299]{0} broadcast(s32[] %constant.571), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.417 = pred[299]{0} compare(s32[299]{0} %convert.67, s32[299]{0} %broadcast.625), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.570 = s32[] constant(100000) %broadcast.624 = s32[299]{0} broadcast(s32[] %constant.570), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" 
source_file="" source_line=58} %add.347 = s32[299]{0} add(s32[299]{0} %convert.67, s32[299]{0} %broadcast.624), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/add" source_file="" source_line=60} %select.409 = s32[299]{0} select(pred[299]{0} %compare.417, s32[299]{0} %add.347, s32[299]{0} %convert.67), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.73 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.409), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.59 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %constant.568, s32[299,1]{1,0} %bitcast.73), dimensions={1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/concatenate" source_file="" source_line=60} ROOT %gather.107 = f32[299]{0} gather(f32[300,100000]{1,0} %param_0.49, s32[299,2]{1,0} %concatenate.59), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60} } %parallel_fusion.7 (p.31: f32[300,100000], p.32: f32[299], p.33: f32[299], p.34: f32[299], p.35: f32[299], p.36: f32[299], p.37: f32[100000], p.38: f32[100000], p.39: f32[299], p.40: s32[]) -> f32[299] { %p.31 = f32[300,100000]{1,0} parameter(0) %p.32 = f32[299]{0} parameter(1) %p.33 = f32[299]{0} parameter(2) %p.34 = f32[299]{0} parameter(3) %p.35 = f32[299]{0} parameter(4) %p.36 = f32[299]{0} parameter(5) %p.37 = f32[100000]{0} parameter(6) %p.38 = f32[100000]{0} parameter(7) %p.39 = f32[299]{0} parameter(8) %p.40 = s32[] parameter(9) ROOT %fusion.7.clone = f32[299]{0} fusion(f32[300,100000]{1,0} %p.31, f32[299]{0} %p.32, f32[299]{0} %p.33, f32[299]{0} %p.34, f32[299]{0} %p.35, /*index=5*/f32[299]{0} %p.36, f32[100000]{0} %p.37, f32[100000]{0} %p.38, f32[299]{0} %p.39, s32[] %p.40), kind=kLoop, 
calls=%fused_computation.7.clone, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/gather" source_file="" source_line=60}, backend_config={"outer_dimension_partitions":["2"]} } %wide.wide.wide.region_0.427.clone.clone.clone (wide.wide.wide.arg_tuple.1: (s32[], f32[300,100000], s32[100000], f32[100000], f32[299], /*index=5*/f32[100000], pred[299], f32[299], s32[2], f32[299])) -> (s32[], f32[300,100000], s32[100000], f32[100000], f32[299], /*index=5*/f32[100000], pred[299], f32[299], s32[2], f32[299]) { %wide.wide.wide.arg_tuple.1 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) parameter(0) %get-tuple-element.651 = f32[100000]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=3 %get-tuple-element.638 = s32[] get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=0 %copy.12 = s32[] copy(s32[] %get-tuple-element.638) %get-tuple-element.650 = s32[100000]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=2 %fusion.16 = s32[] fusion(s32[100000]{0} %get-tuple-element.650, s32[] %copy.12), kind=kLoop, calls=%fused_computation.16, metadata={op_name="jit(main1)/jit(main)/while/body/select_n" source_file="" source_line=73} %get-tuple-element.657 = f32[299]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=9 %fusion.13 = 
s32[299]{0} fusion(f32[299]{0} %get-tuple-element.657, f32[100000]{0} %get-tuple-element.651, s32[] %fusion.16), kind=kLoop, calls=%fused_computation.13 %constant.424 = s32[] constant(0) %copy.17 = s32[] copy(s32[] %constant.424) %broadcast.448 = s32[299]{0} broadcast(s32[] %copy.17), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %constant.427 = s32[] constant(100000) %broadcast.449 = s32[299]{0} broadcast(s32[] %constant.427), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %tuple.57 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) tuple(s32[] %copy.17, s32[299]{0} %broadcast.448, s32[299]{0} %broadcast.449, f32[100000]{0} %get-tuple-element.651, s32[299]{0} %fusion.13) %while.22 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) while((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %tuple.57), condition=%wide.region_2.114.clone.1.clone.1, body=%wide.region_1.101.clone.1.clone.1, backend_config={"known_trip_count":{"n":"17"}} %get-tuple-element.558 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %while.22), index=2, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/jit(searchsorted)/while" source_file="" source_line=58} %fusion.11 = f32[299]{0} fusion(f32[100000]{0} %get-tuple-element.651, s32[299]{0} %get-tuple-element.558), kind=kLoop, calls=%fused_computation.11, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} %fusion.12 = f32[299]{0} fusion(f32[100000]{0} %get-tuple-element.651, s32[299]{0} %get-tuple-element.558), kind=kLoop, calls=%fused_computation.12, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} %fusion.9 = f32[299]{0} fusion(f32[299]{0} 
%fusion.11, f32[299]{0} %fusion.12, f32[299]{0} %get-tuple-element.657, f32[100000]{0} %get-tuple-element.651, s32[] %fusion.16), kind=kLoop, calls=%fused_computation.9, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/div" source_file="" source_line=58} %get-tuple-element.653 = f32[100000]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=5 %fusion.10 = f32[299]{0} fusion(f32[100000]{0} %get-tuple-element.653, s32[299]{0} %get-tuple-element.558), kind=kLoop, calls=%fused_computation.10, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} %fusion.8 = f32[299]{0} fusion(f32[100000]{0} %get-tuple-element.653, s32[299]{0} %get-tuple-element.558), kind=kLoop, calls=%fused_computation.8, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/gather" source_file="" source_line=58} %get-tuple-element.639 = f32[300,100000]{1,0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=1 %call.14 = f32[299]{0} call(f32[300,100000]{1,0} %get-tuple-element.639, f32[299]{0} %fusion.10, f32[299]{0} %fusion.9, f32[299]{0} %fusion.8, f32[299]{0} %fusion.12, /*index=5*/f32[299]{0} %fusion.11, f32[100000]{0} %get-tuple-element.653, f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.657, s32[] %fusion.16), to_apply=%parallel_fusion.7 %call.13 = f32[299]{0} call(f32[300,100000]{1,0} %get-tuple-element.639, f32[299]{0} %fusion.10, f32[299]{0} %fusion.9, f32[299]{0} %fusion.8, f32[299]{0} %fusion.12, /*index=5*/f32[299]{0} %fusion.11, f32[100000]{0} %get-tuple-element.653, f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.657, s32[] %fusion.16), 
to_apply=%parallel_fusion.6 %call.12 = f32[299]{0} call(f32[300,100000]{1,0} %get-tuple-element.639, f32[299]{0} %fusion.10, f32[299]{0} %fusion.9, f32[299]{0} %fusion.8, f32[299]{0} %fusion.12, /*index=5*/f32[299]{0} %fusion.11, f32[100000]{0} %get-tuple-element.653, f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.657, s32[] %fusion.16), to_apply=%parallel_fusion.5 %call.11 = f32[299]{0} call(f32[300,100000]{1,0} %get-tuple-element.639, f32[299]{0} %fusion.10, f32[299]{0} %fusion.9, f32[299]{0} %fusion.8, f32[299]{0} %fusion.12, /*index=5*/f32[299]{0} %fusion.11, f32[100000]{0} %get-tuple-element.653, f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.657, s32[] %fusion.16), to_apply=%parallel_fusion.4 %fusion.15 = s32[2]{0} fusion(s32[] %fusion.16), kind=kLoop, calls=%fused_computation.15 %get-tuple-element.656 = s32[2]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=8 %fusion.14 = pred[] fusion(s32[2]{0} %get-tuple-element.656, s32[2]{0} %fusion.15), kind=kLoop, calls=%fused_computation.14 %get-tuple-element.652 = f32[299]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=4 %get-tuple-element.654 = pred[299]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=6 %get-tuple-element.655 = f32[299]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.1), index=7 %fusion.3 = f32[300,100000]{1,0} 
fusion(f32[300,100000]{1,0} %get-tuple-element.639, s32[] %fusion.16, f32[299]{0} %call.11, f32[299]{0} %call.12, f32[299]{0} %call.13, /*index=5*/pred[299]{0} %get-tuple-element.654, f32[299]{0} %get-tuple-element.655, f32[299]{0} %call.14, f32[299]{0} %get-tuple-element.652, f32[299]{0} %fusion.10, /*index=10*/f32[299]{0} %fusion.9, f32[299]{0} %fusion.8, f32[299]{0} %fusion.12, f32[299]{0} %fusion.11, f32[100000]{0} %get-tuple-element.653, /*index=15*/f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.657, pred[] %fusion.14), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(main1)/jit(main)/while/body/scatter" source_file="" source_line=73} %constant.423 = s32[] constant(1) %add.254 = s32[] add(s32[] %copy.12, s32[] %constant.423), metadata={op_name="jit(main1)/jit(main)/while/body/add" source_file="" source_line=26} ROOT %tuple.62 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) tuple(s32[] %add.254, f32[300,100000]{1,0} %fusion.3, s32[100000]{0} %get-tuple-element.650, f32[100000]{0} %get-tuple-element.651, f32[299]{0} %get-tuple-element.652, /*index=5*/f32[100000]{0} %get-tuple-element.653, pred[299]{0} %get-tuple-element.654, f32[299]{0} %get-tuple-element.655, s32[2]{0} %get-tuple-element.656, f32[299]{0} %get-tuple-element.657) } %wide.wide.wide.region_4.446.clone.clone.clone (wide.wide.wide.arg_tuple.0: (s32[], f32[300,100000], s32[100000], f32[100000], f32[299], /*index=5*/f32[100000], pred[299], f32[299], s32[2], f32[299])) -> pred[] { %constant.422 = s32[] constant(100000) %wide.wide.wide.arg_tuple.0 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) parameter(0) %get-tuple-element.547 = s32[] get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, 
pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %wide.wide.wide.arg_tuple.0), index=0 ROOT %compare.286 = pred[] compare(s32[] %get-tuple-element.547, s32[] %constant.422), direction=LT, metadata={op_name="jit(main1)/jit(main)/while/cond/lt" source_file="" source_line=26} } %fused_computation.17 () -> s32[2] { %iota.9 = s32[2]{0} iota(), iota_dimension=0 %constant.551 = s32[] constant(99999) %broadcast.593 = s32[2]{0} broadcast(s32[] %constant.551), dimensions={} ROOT %multiply.107 = s32[2]{0} multiply(s32[2]{0} %iota.9, s32[2]{0} %broadcast.593) } %fused_computation.18 () -> f32[100000] { %iota.10 = s32[100000]{0} iota(), iota_dimension=0, metadata={op_name="jit(main1)/jit(main)/while/body/iota" source_file="" source_line=58} ROOT %convert.63 = f32[100000]{0} convert(s32[100000]{0} %iota.10), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/convert_element_type" source_file="" source_line=58} } %fused_computation.19 () -> f32[100000] { %constant.555 = f32[] constant(0) %broadcast.596 = f32[99999]{0} broadcast(f32[] %constant.555), dimensions={}, metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/mul" source_file="" source_line=12} %constant.554 = f32[] constant(1) %broadcast.595 = f32[99999]{0} broadcast(f32[] %constant.554), dimensions={} %iota.11 = f32[99999]{0} iota(), iota_dimension=0, metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/iota" source_file="" source_line=12} %constant.553 = f32[] constant(1.00001e-05) %broadcast.594 = f32[99999]{0} broadcast(f32[] %constant.553), dimensions={} %multiply.109 = f32[99999]{0} multiply(f32[99999]{0} %iota.11, f32[99999]{0} %broadcast.594), metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/div" source_file="" source_line=12} %subtract.86 = f32[99999]{0} subtract(f32[99999]{0} %broadcast.595, f32[99999]{0} %multiply.109), metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/sub" source_file="" source_line=12} %multiply.108 = f32[99999]{0} multiply(f32[99999]{0} %broadcast.596, 
f32[99999]{0} %subtract.86), metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/mul" source_file="" source_line=12} %add.333 = f32[99999]{0} add(f32[99999]{0} %multiply.108, f32[99999]{0} %multiply.109), metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/add" source_file="" source_line=12} %constant.552 = f32[1]{0} constant({1}), metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/broadcast_in_dim" source_file="" source_line=12} ROOT %concatenate.55 = f32[100000]{0} concatenate(f32[99999]{0} %add.333, f32[1]{0} %constant.552), dimensions={0}, metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/concatenate" source_file="" source_line=12} } %parallel_broadcast.2 (p: f32[]) -> f32[300,100000] { %p = f32[] parameter(0) ROOT %broadcast.2.clone = f32[300,100000]{1,0} broadcast(f32[] %p), dimensions={}, backend_config={"outer_dimension_partitions":["2"]} } ENTRY %main.457 () -> f32[300,100000] { %constant.1 = f32[] constant(0) %call.10 = f32[300,100000]{1,0} call(f32[] %constant.1), to_apply=%parallel_broadcast.2 %fusion.18 = f32[100000]{0} fusion(), kind=kLoop, calls=%fused_computation.18, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_interp)/convert_element_type" source_file="" source_line=58} %fusion.19 = f32[100000]{0} fusion(), kind=kLoop, calls=%fused_computation.19, metadata={op_name="jit(main1)/jit(main)/jit(_linspace)/concatenate" source_file="" source_line=12} %iota.31 = s32[100000]{0} iota(), iota_dimension=0, metadata={op_name="jit(main1)/jit(main)/iota" source_file="" source_line=26} %constant.58 = f32[] constant(1) %broadcast.260 = f32[299]{0} broadcast(f32[] %constant.58), dimensions={}, metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/sub" source_file="" source_line=60} %constant.269 = pred[] constant(true), metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %broadcast.266 = pred[299]{0} broadcast(pred[] %constant.269), dimensions={}, 
metadata={op_name="jit(main1)/jit(main)/while/body/jit(_map_coordinates)/and" source_file="" source_line=60} %fusion.17 = s32[2]{0} fusion(), kind=kLoop, calls=%fused_computation.17 %constant.3 = s32[] constant(0) %copy.16 = s32[] copy(s32[] %constant.3) %constant.476 = f32[299]{0} constant({...}), metadata={op_name="jit(main1)/jit(main)/while/body/log" source_file="" source_line=56} %tuple.60 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) tuple(s32[] %copy.16, f32[300,100000]{1,0} %call.10, s32[100000]{0} %iota.31, f32[100000]{0} %fusion.19, f32[299]{0} %broadcast.260, /*index=5*/f32[100000]{0} %fusion.18, pred[299]{0} %broadcast.266, f32[299]{0} %broadcast.260, s32[2]{0} %fusion.17, f32[299]{0} %constant.476) %while.20 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) while((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %tuple.60), condition=%wide.wide.wide.region_4.446.clone.clone.clone, body=%wide.wide.wide.region_0.427.clone.clone.clone, backend_config={"known_trip_count":{"n":"100000"}} ROOT %get-tuple-element.456 = f32[300,100000]{1,0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, f32[299]{0}, /*index=5*/f32[100000]{0}, pred[299]{0}, f32[299]{0}, s32[2]{0}, f32[299]{0}) %while.20), index=1, metadata={op_name="jit(main1)/jit(main)/while" source_file="" source_line=26} } ```
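Dumps like this are post-optimization HLO. The issue doesn't say how they were produced. The sketch below assumes JAX's ahead-of-time API (`jax.jit(f).lower(...).compile().as_text()`). It prints the compiled HLO for a direct call and for the same call wrapped in an always-true `lax.cond`, so the two can be diffed.

`step_fn`, `false_fn` and `make_main` are hypothetical stand-ins for the issue's `interp_A_from_B`, `false_func` and `main`, and they use much smaller arrays. This toy version is not guaranteed to reproduce the same HLO differences as the full example.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for interp_A_from_B: some per-step interpolation work.
def step_fn(args):
    y, xs = args
    return jnp.interp(xs, xs, jnp.exp(-y * xs))

# Hypothetical stand-in for false_func: lax.cond needs matching output shape/dtype.
def false_fn(args):
    _, xs = args
    return jnp.zeros_like(xs)

def make_main(use_cond):
    def main(xs):
        def body(carry, i):
            y = xs[i]
            if use_cond:
                # Always true at runtime, but not a compile-time constant.
                out = lax.cond(i > -1, step_fn, false_fn, (y, xs))
            else:
                out = step_fn((y, xs))
            return carry + out, None
        final, _ = lax.scan(body, jnp.zeros_like(xs), jnp.arange(xs.shape[0]))
        return final
    return main

xs = jnp.linspace(0.0, 1.0, 64)

# Optimized HLO text for each variant, suitable for diffing.
hlo_direct = jax.jit(make_main(False)).lower(xs).compile().as_text()
hlo_cond = jax.jit(make_main(True)).lower(xs).compile().as_text()
print(len(hlo_direct.splitlines()), len(hlo_cond.splitlines()))

# The predicate is always true, so both variants should compute the same values.
same = bool(jnp.allclose(jax.jit(make_main(False))(xs), jax.jit(make_main(True))(xs)))
print(same)
```

Diffing `hlo_direct` against `hlo_cond` shows where the `cond` wrapper changes how XLA fuses or schedules the loop body.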

And for case 2, the output is this: ``` HloModule jit_main2, is_scheduled=true, entry_computation_layout={()->f32[300,100000]{1,0}}, allow_spmd_sharding_propagation_to_output={true} %fused_computation (param_0.1: f32[300]) -> f32[300] { %param_0.1 = f32[300]{0} parameter(0) %negate.25 = f32[300]{0} negate(f32[300]{0} %param_0.1), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_0_fun/neg" source_file="" source_line=78} ROOT %exponential.1 = f32[300]{0} exponential(f32[300]{0} %negate.25), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_0_fun/exp" source_file="" source_line=78} } %region_1.33 (Arg_.34: f32[300]) -> (f32[300]) { %Arg_.34 = f32[300]{0} parameter(0), metadata={op_name="jit(main2)/jit(main)/while/body/closed_call"} %fusion = f32[300]{0} fusion(f32[300]{0} %Arg_.34), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_0_fun/exp" source_file="" source_line=78} ROOT %tuple.28 = (f32[300]{0}) tuple(f32[300]{0} %fusion) } %fused_computation.1 (param_0.4: s32[299], param_1.2: s32[299], param_2.3: f32[299,1], param_3.8: s32[299]) -> s32[299] { %param_1.2 = s32[299]{0} parameter(1) %param_2.3 = f32[299,1]{1,0} parameter(2) %bitcast.16 = f32[299]{0} bitcast(f32[299,1]{1,0} %param_2.3), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %compare.105 = pred[299]{0} compare(f32[299]{0} %bitcast.16, f32[299]{0} %bitcast.16), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.194 = f32[] constant(nan) %broadcast.176 = f32[299]{0} broadcast(f32[] %constant.194), dimensions={} %constant.193 = f32[] constant(0) %broadcast.175 = f32[299]{0} broadcast(f32[] %constant.193), dimensions={} %compare.104 = pred[299]{0} compare(f32[299]{0} %bitcast.16, f32[299]{0} %broadcast.175), direction=EQ,
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.110 = f32[299]{0} select(pred[299]{0} %compare.104, f32[299]{0} %broadcast.175, f32[299]{0} %bitcast.16), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.109 = f32[299]{0} select(pred[299]{0} %compare.105, f32[299]{0} %broadcast.176, f32[299]{0} %select.110), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.9 = s32[299]{0} bitcast-convert(f32[299]{0} %select.109) %constant.195 = s32[] constant(0) %broadcast.180 = s32[299]{0} broadcast(s32[] %constant.195), dimensions={} %compare.103 = pred[299]{0} compare(s32[299]{0} %bitcast-convert.9, s32[299]{0} %broadcast.180), direction=LT %constant.192 = s32[] constant(2147483647) %broadcast.174 = s32[299]{0} broadcast(s32[] %constant.192), dimensions={} %xor.9 = s32[299]{0} xor(s32[299]{0} %broadcast.174, s32[299]{0} %bitcast-convert.9) %select.108 = s32[299]{0} select(pred[299]{0} %compare.103, s32[299]{0} %xor.9, s32[299]{0} %bitcast-convert.9) %compare.102 = pred[299]{0} compare(s32[299]{0} %param_1.2, s32[299]{0} %select.108), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt_to" source_file="" source_line=58} %param_3.8 = s32[299]{0} parameter(3) %param_0.4 = s32[299]{0} parameter(0) %add.70 = s32[299]{0} add(s32[299]{0} %param_3.8, s32[299]{0} %param_0.4), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.9 = s32[299]{0} sign(s32[299]{0} %add.70), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.196 = s32[] constant(1) %broadcast.182 = s32[299]{0} broadcast(s32[] %constant.196), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.108 = pred[299]{0} compare(s32[299]{0} %sign.9, s32[299]{0} %broadcast.182), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.107 = pred[299]{0} compare(s32[299]{0} %add.70, s32[299]{0} %broadcast.180), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.28 = s32[299]{0} negate(s32[299]{0} %add.70), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.114 = s32[299]{0} select(pred[299]{0} %compare.107, s32[299]{0} %negate.28, s32[299]{0} %add.70), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.22 = s32[299]{0} and(s32[299]{0} %select.114, s32[299]{0} %broadcast.182), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.27 = s32[299]{0} negate(s32[299]{0} %and.22), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.113 = s32[299]{0} 
select(pred[299]{0} %compare.107, s32[299]{0} %negate.27, s32[299]{0} %and.22), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.106 = pred[299]{0} compare(s32[299]{0} %select.113, s32[299]{0} %broadcast.180), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %and.21 = pred[299]{0} and(pred[299]{0} %compare.108, pred[299]{0} %compare.106), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.4 = s32[299]{0} shift-right-logical(s32[299]{0} %select.114, s32[299]{0} %broadcast.182), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.26 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.4), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.112 = s32[299]{0} select(pred[299]{0} %compare.107, s32[299]{0} %negate.26, s32[299]{0} %shift-right-logical.4), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.117.clone.4 = s32[] constant(-1) %broadcast.178 = s32[299]{0} broadcast(s32[] %constant.117.clone.4), dimensions={} %add.69 = s32[299]{0} add(s32[299]{0} %select.112, s32[299]{0} %broadcast.178), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.111 = s32[299]{0} 
select(pred[299]{0} %and.21, s32[299]{0} %add.69, s32[299]{0} %select.112), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} ROOT %select.107 = s32[299]{0} select(pred[299]{0} %compare.102, s32[299]{0} %select.111, s32[299]{0} %param_0.4), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} } %fused_computation.2 (param_0.6: s32[299], param_1.5: s32[299], param_2.7: f32[299,1], param_3.17: s32[299]) -> s32[299] { %param_1.5 = s32[299]{0} parameter(1) %param_2.7 = f32[299,1]{1,0} parameter(2) %bitcast.17 = f32[299]{0} bitcast(f32[299,1]{1,0} %param_2.7), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %compare.112 = pred[299]{0} compare(f32[299]{0} %bitcast.17, f32[299]{0} %bitcast.17), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.199 = f32[] constant(nan) %broadcast.186 = f32[299]{0} broadcast(f32[] %constant.199), dimensions={} %constant.198 = f32[] constant(0) %broadcast.185 = f32[299]{0} broadcast(f32[] %constant.198), dimensions={} %compare.111 = pred[299]{0} compare(f32[299]{0} %bitcast.17, f32[299]{0} %broadcast.185), direction=EQ, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.118 = f32[299]{0} select(pred[299]{0} %compare.111, f32[299]{0} %broadcast.185, f32[299]{0} %bitcast.17), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.117 = f32[299]{0} 
select(pred[299]{0} %compare.112, f32[299]{0} %broadcast.186, f32[299]{0} %select.118), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.10 = s32[299]{0} bitcast-convert(f32[299]{0} %select.117) %constant.200 = s32[] constant(0) %broadcast.188 = s32[299]{0} broadcast(s32[] %constant.200), dimensions={} %compare.110 = pred[299]{0} compare(s32[299]{0} %bitcast-convert.10, s32[299]{0} %broadcast.188), direction=LT %constant.197 = s32[] constant(2147483647) %broadcast.184 = s32[299]{0} broadcast(s32[] %constant.197), dimensions={} %xor.10 = s32[299]{0} xor(s32[299]{0} %broadcast.184, s32[299]{0} %bitcast-convert.10) %select.116 = s32[299]{0} select(pred[299]{0} %compare.110, s32[299]{0} %xor.10, s32[299]{0} %bitcast-convert.10) %compare.109 = pred[299]{0} compare(s32[299]{0} %param_1.5, s32[299]{0} %select.116), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt_to" source_file="" source_line=58} %param_0.6 = s32[299]{0} parameter(0) %param_3.17 = s32[299]{0} parameter(3) %add.72 = s32[299]{0} add(s32[299]{0} %param_0.6, s32[299]{0} %param_3.17), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.10 = s32[299]{0} sign(s32[299]{0} %add.72), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.201 = s32[] constant(1) %broadcast.189 = s32[299]{0} broadcast(s32[] %constant.201), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.115 = pred[299]{0} compare(s32[299]{0} %sign.10, s32[299]{0} 
%broadcast.189), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.114 = pred[299]{0} compare(s32[299]{0} %add.72, s32[299]{0} %broadcast.188), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.31 = s32[299]{0} negate(s32[299]{0} %add.72), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.122 = s32[299]{0} select(pred[299]{0} %compare.114, s32[299]{0} %negate.31, s32[299]{0} %add.72), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.24 = s32[299]{0} and(s32[299]{0} %select.122, s32[299]{0} %broadcast.189), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.30 = s32[299]{0} negate(s32[299]{0} %and.24), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.121 = s32[299]{0} select(pred[299]{0} %compare.114, s32[299]{0} %negate.30, s32[299]{0} %and.24), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.113 = pred[299]{0} compare(s32[299]{0} %select.121, s32[299]{0} %broadcast.188), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" 
source_line=58} %and.23 = pred[299]{0} and(pred[299]{0} %compare.115, pred[299]{0} %compare.113), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.5 = s32[299]{0} shift-right-logical(s32[299]{0} %select.122, s32[299]{0} %broadcast.189), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.29 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.5), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.120 = s32[299]{0} select(pred[299]{0} %compare.114, s32[299]{0} %negate.29, s32[299]{0} %shift-right-logical.5), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.117.clone.5 = s32[] constant(-1) %broadcast.187 = s32[299]{0} broadcast(s32[] %constant.117.clone.5), dimensions={} %add.71 = s32[299]{0} add(s32[299]{0} %select.120, s32[299]{0} %broadcast.187), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.119 = s32[299]{0} select(pred[299]{0} %and.23, s32[299]{0} %add.71, s32[299]{0} %select.120), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} ROOT %select.115 = s32[299]{0} select(pred[299]{0} %compare.109, s32[299]{0} %param_0.6, s32[299]{0} %select.119), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} } %fused_computation.3 (param_0.7: f32[100000], param_1.14: s32[299], param_2.19: s32[299]) -> f32[299,1] { %param_0.7 = f32[100000]{0} parameter(0) %param_1.14 = s32[299]{0} parameter(1) %param_2.19 = s32[299]{0} parameter(2) %add.75 = s32[299]{0} add(s32[299]{0} %param_1.14, s32[299]{0} %param_2.19), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %sign.11 = s32[299]{0} sign(s32[299]{0} %add.75), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sign" source_file="" source_line=58} %constant.204 = s32[] constant(1) %broadcast.193 = s32[299]{0} broadcast(s32[] %constant.204), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %compare.119 = pred[299]{0} compare(s32[299]{0} %sign.11, s32[299]{0} %broadcast.193), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %constant.203 = s32[] constant(0) %broadcast.192 = s32[299]{0} broadcast(s32[] %constant.203), dimensions={} %compare.118 = pred[299]{0} compare(s32[299]{0} %add.75, s32[299]{0} %broadcast.192), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.34 = s32[299]{0} negate(s32[299]{0} %add.75), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" 
source_line=58} %select.127 = s32[299]{0} select(pred[299]{0} %compare.118, s32[299]{0} %negate.34, s32[299]{0} %add.75), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %and.26 = s32[299]{0} and(s32[299]{0} %select.127, s32[299]{0} %broadcast.193), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %negate.33 = s32[299]{0} negate(s32[299]{0} %and.26), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %select.126 = s32[299]{0} select(pred[299]{0} %compare.118, s32[299]{0} %negate.33, s32[299]{0} %and.26), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/rem" source_file="" source_line=58} %compare.117 = pred[299]{0} compare(s32[299]{0} %select.126, s32[299]{0} %broadcast.192), direction=NE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/ne" source_file="" source_line=58} %and.25 = pred[299]{0} and(pred[299]{0} %compare.119, pred[299]{0} %compare.117), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/and" source_file="" source_line=58} %shift-right-logical.6 = s32[299]{0} shift-right-logical(s32[299]{0} %select.127, s32[299]{0} %broadcast.193), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %negate.32 = s32[299]{0} negate(s32[299]{0} %shift-right-logical.6), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %select.125 = s32[299]{0} select(pred[299]{0} %compare.118, s32[299]{0} %negate.32, s32[299]{0} %shift-right-logical.6), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/div" source_file="" source_line=58} %constant.117.clone.6 = s32[] constant(-1) %broadcast.191 = s32[299]{0} broadcast(s32[] %constant.117.clone.6), dimensions={} %add.74 = s32[299]{0} add(s32[299]{0} %select.125, s32[299]{0} %broadcast.191), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/sub" source_file="" source_line=58} %select.124 = s32[299]{0} select(pred[299]{0} %and.25, s32[299]{0} %add.74, s32[299]{0} %select.125), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(floor_divide)/jit(_where)/select_n" source_file="" source_line=58} %compare.116 = pred[299]{0} compare(s32[299]{0} %select.124, s32[299]{0} %broadcast.192), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/lt" source_file="" source_line=58} %constant.202 = s32[] constant(100000) %broadcast.190 = s32[299]{0} broadcast(s32[] %constant.202), dimensions={} %add.73 = s32[299]{0} add(s32[299]{0} %select.124, s32[299]{0} %broadcast.190), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} %select.123 = s32[299]{0} select(pred[299]{0} %compare.116, s32[299]{0} %add.73, s32[299]{0} %select.124), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} 
%bitcast.18 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.123), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} ROOT %gather.15 = f32[299,1]{1,0} gather(f32[100000]{0} %param_0.7, s32[299,1]{1,0} %bitcast.18), offset_dims={1}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} } %wide.region_3.105.clone.clone (wide.arg_tuple.3: (s32[], s32[299], s32[299], f32[100000], s32[299])) -> (s32[], s32[299], s32[299], f32[100000], s32[299]) { %wide.arg_tuple.3 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) parameter(0) %get-tuple-element.125 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.3), index=1 %copy.4 = s32[299]{0} copy(s32[299]{0} %get-tuple-element.125) %get-tuple-element.126 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.3), index=2 %copy.5 = s32[299]{0} copy(s32[299]{0} %get-tuple-element.126) %get-tuple-element.132 = f32[100000]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.3), index=3 %fusion.3 = f32[299,1]{1,0} fusion(f32[100000]{0} %get-tuple-element.132, s32[299]{0} %copy.4, s32[299]{0} %copy.5), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/gather" source_file="" source_line=58} %get-tuple-element.133 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.3), index=4 %fusion.1 = s32[299]{0} fusion(s32[299]{0} %copy.5, s32[299]{0} %get-tuple-element.133, f32[299,1]{1,0} %fusion.3, s32[299]{0} %copy.4), 
kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} %fusion.2 = s32[299]{0} fusion(s32[299]{0} %copy.4, s32[299]{0} %get-tuple-element.133, f32[299,1]{1,0} %fusion.3, s32[299]{0} %copy.5), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/jit(_where)/select_n" source_file="" source_line=58} %constant.172 = s32[] constant(1) %get-tuple-element.119 = s32[] get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.3), index=0 %add.65 = s32[] add(s32[] %get-tuple-element.119, s32[] %constant.172), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/add" source_file="" source_line=58} ROOT %tuple.27 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) tuple(s32[] %add.65, s32[299]{0} %fusion.2, s32[299]{0} %fusion.1, f32[100000]{0} %get-tuple-element.132, s32[299]{0} %get-tuple-element.133) } %wide.region_4.118.clone.clone (wide.arg_tuple.2: (s32[], s32[299], s32[299], f32[100000], s32[299])) -> pred[] { %constant.171 = s32[] constant(17) %wide.arg_tuple.2 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) parameter(0) %get-tuple-element.93 = s32[] get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %wide.arg_tuple.2), index=0 ROOT %compare.86 = pred[] compare(s32[] %get-tuple-element.93, s32[] %constant.171), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/cond/lt" source_file="" source_line=58} } %fused_computation.4 (param_0.9: f32[300], param_1.21: f32[300,100000], param_2.24: s32[299], param_3.37: f32[299], param_4.48: f32[299], param_5.47: f32[100000], param_6.52: 
s32[299], param_7.35: s32[], param_8.37: f32[299], param_9.30: f32[]) -> f32[300] { %param_9.30 = f32[] parameter(9) %broadcast.208 = f32[299]{0} broadcast(f32[] %param_9.30), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/sub" source_file="" source_line=56} %param_8.37 = f32[299]{0} parameter(8) %subtract.22 = f32[299]{0} subtract(f32[299]{0} %broadcast.208, f32[299]{0} %param_8.37), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/sub" source_file="" source_line=56} %param_5.47 = f32[100000]{0} parameter(5) %slice.16 = f32[1]{0} slice(f32[100000]{0} %param_5.47), slice={[0:1]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/slice" source_file="" source_line=62} %bitcast.27 = f32[] bitcast(f32[1]{0} %slice.16), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/slice" source_file="" source_line=62} %broadcast.207 = f32[299]{0} broadcast(f32[] %bitcast.27), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/lt" source_file="" source_line=62} %compare.136 = pred[299]{0} compare(f32[299]{0} %subtract.22, f32[299]{0} %broadcast.207), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/lt" source_file="" source_line=62} %param_0.9 = f32[300]{0} parameter(0) %slice.17 = f32[299]{0} slice(f32[300]{0} %param_0.9), slice={[0:299]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/slice" source_file="" source_line=56} %negate.37 = f32[299]{0} negate(f32[299]{0} %slice.17), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/neg" source_file="" source_line=65} %exponential.3 = f32[299]{0} exponential(f32[299]{0} %negate.37), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/exp" source_file="" source_line=65} %constant.213 = f32[] constant(0) %broadcast.206 = f32[299]{0} broadcast(f32[] %constant.213), dimensions={} %select.141 = f32[299]{0} select(pred[299]{0} 
%compare.136, f32[299]{0} %exponential.3, f32[299]{0} %broadcast.206), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/mul" source_file="" source_line=65} %constant.212 = s32[] constant(1) %broadcast.205 = s32[299]{0} broadcast(s32[] %constant.212), dimensions={} %convert.21 = s32[299]{0} convert(pred[299]{0} %compare.136), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/convert_element_type" source_file="" source_line=66} %subtract.21 = s32[299]{0} subtract(s32[299]{0} %broadcast.205, s32[299]{0} %convert.21), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/sub" source_file="" source_line=66} %convert.20 = f32[299]{0} convert(s32[299]{0} %subtract.21), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/convert_element_type" source_file="" source_line=66} %constant.211 = f32[] constant(1) %broadcast.204 = f32[299]{0} broadcast(f32[] %constant.211), dimensions={} %slice.15 = f32[1]{0} slice(f32[100000]{0} %param_5.47), slice={[99999:100000]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.26 = f32[] bitcast(f32[1]{0} %slice.15), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.203 = f32[299]{0} broadcast(f32[] %bitcast.26), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gt" source_file="" source_line=58} %compare.135 = pred[299]{0} compare(f32[299]{0} %subtract.22, f32[299]{0} %broadcast.203), direction=GT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gt" source_file="" source_line=58} %iota.5 = s32[100000]{0} iota(), iota_dimension=0, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/iota" source_file="" source_line=58} %convert.19 = f32[100000]{0} convert(s32[100000]{0} %iota.5), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/convert_element_type" source_file="" source_line=58} %slice.14 = f32[1]{0} slice(f32[100000]{0} %convert.19), slice={[99999:100000]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/dynamic_slice" source_file="" source_line=58} %bitcast.25 = f32[] bitcast(f32[1]{0} %slice.14), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/dynamic_slice" source_file="" source_line=58} %broadcast.202 = f32[299]{0} broadcast(f32[] %bitcast.25), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %slice.13 = f32[1]{0} slice(f32[100000]{0} %convert.19), slice={[0:1]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/slice" source_file="" source_line=58} %bitcast.24 = f32[] bitcast(f32[1]{0} %slice.13), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/slice" source_file="" source_line=58} %broadcast.201 = f32[299]{0} broadcast(f32[] %bitcast.24), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/broadcast_in_dim" source_file="" source_line=58} %param_6.52 = s32[299]{0} parameter(6) %constant.209 = s32[] constant(99999) %broadcast.198 = s32[299]{0} broadcast(s32[] %constant.209), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.1 = s32[299]{0} clamp(s32[299]{0} %broadcast.205, s32[299]{0} %param_6.52, s32[299]{0} %broadcast.198), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %param_7.35 = s32[] parameter(7) %broadcast.200 = s32[299]{0} broadcast(s32[] %param_7.35), dimensions={}, 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.134 = pred[299]{0} compare(s32[299]{0} %clamp.1, s32[299]{0} %broadcast.200), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/lt" source_file="" source_line=58} %constant.210 = s32[] constant(100000) %broadcast.199 = s32[299]{0} broadcast(s32[] %constant.210), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %add.89 = s32[299]{0} add(s32[299]{0} %clamp.1, s32[299]{0} %broadcast.199), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/add" source_file="" source_line=58} %select.140 = s32[299]{0} select(pred[299]{0} %compare.134, s32[299]{0} %add.89, s32[299]{0} %clamp.1), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.23 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.140), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} %gather.21 = f32[299]{0} gather(f32[100000]{0} %param_5.47, s32[299,1]{1,0} %bitcast.23), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} %param_4.48 = f32[299]{0} parameter(4) %subtract.20 = f32[299]{0} subtract(f32[299]{0} %gather.21, f32[299]{0} %param_4.48), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/sub" source_file="" source_line=58} %abs.1 = f32[299]{0} abs(f32[299]{0} %subtract.20), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/abs" source_file="" source_line=58} %constant.208 = f32[] 
constant(1.42108547e-14) %broadcast.197 = f32[299]{0} broadcast(f32[] %constant.208), dimensions={} %compare.133 = pred[299]{0} compare(f32[299]{0} %abs.1, f32[299]{0} %broadcast.197), direction=LE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/le" source_file="" source_line=58} %param_3.37 = f32[299]{0} parameter(3) %subtract.19 = f32[299]{0} subtract(f32[299]{0} %subtract.22, f32[299]{0} %param_4.48), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/sub" source_file="" source_line=58} %select.139 = f32[299]{0} select(pred[299]{0} %compare.133, f32[299]{0} %broadcast.204, f32[299]{0} %subtract.20), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %divide.4 = f32[299]{0} divide(f32[299]{0} %subtract.19, f32[299]{0} %select.139), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/div" source_file="" source_line=58} %gather.20 = f32[299]{0} gather(f32[100000]{0} %convert.19, s32[299,1]{1,0} %bitcast.23), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} %subtract.18 = f32[299]{0} subtract(f32[299]{0} %gather.20, f32[299]{0} %param_3.37), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/sub" source_file="" source_line=58} %multiply.28 = f32[299]{0} multiply(f32[299]{0} %divide.4, f32[299]{0} %subtract.18), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/mul" source_file="" source_line=58} %add.87 = f32[299]{0} add(f32[299]{0} %param_3.37, f32[299]{0} %multiply.28), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/add" source_file="" source_line=58} %select.138 = f32[299]{0} select(pred[299]{0} %compare.133, f32[299]{0} 
%param_3.37, f32[299]{0} %add.87), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.137 = f32[299]{0} select(pred[299]{0} %compare.136, f32[299]{0} %broadcast.201, f32[299]{0} %select.138), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %select.136 = f32[299]{0} select(pred[299]{0} %compare.135, f32[299]{0} %broadcast.202, f32[299]{0} %select.137), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(_where)/select_n" source_file="" source_line=58} %floor.1 = f32[299]{0} floor(f32[299]{0} %select.136), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/floor" source_file="" source_line=60} %subtract.16 = f32[299]{0} subtract(f32[299]{0} %select.136, f32[299]{0} %floor.1), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/sub" source_file="" source_line=60} %subtract.15 = f32[299]{0} subtract(f32[299]{0} %broadcast.204, f32[299]{0} %subtract.16), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/sub" source_file="" source_line=60} %param_2.24 = s32[299]{0} parameter(2) %compare.132 = pred[299]{0} compare(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.200), direction=GE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/ge" source_file="" source_line=60} %constant.207 = s32[] constant(300) %broadcast.196 = s32[299]{0} broadcast(s32[] %constant.207), dimensions={} %compare.131 = pred[299]{0} compare(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.196), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.34 = pred[299]{0} and(pred[299]{0} %compare.132, pred[299]{0} %compare.131), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %convert.18 = s32[299]{0} convert(f32[299]{0} %floor.1), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/convert_element_type" source_file="" source_line=60} %compare.130 = pred[299]{0} compare(s32[299]{0} %convert.18, s32[299]{0} %broadcast.200), direction=GE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/ge" source_file="" source_line=60} %compare.129 = pred[299]{0} compare(s32[299]{0} %convert.18, s32[299]{0} %broadcast.199), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.33 = pred[299]{0} and(pred[299]{0} %compare.130, pred[299]{0} %compare.129), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %and.32 = pred[299]{0} and(pred[299]{0} %and.34, pred[299]{0} %and.33), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %param_1.21 = f32[300,100000]{1,0} parameter(1) %compare.128 = pred[299]{0} compare(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.200), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %add.86 = s32[299]{0} add(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.196), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %select.135 = s32[299]{0} select(pred[299]{0} %compare.128, s32[299]{0} %add.86, s32[299]{0} %param_2.24), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.22 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.135), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %compare.127 = pred[299]{0} compare(s32[299]{0} %convert.18, s32[299]{0} %broadcast.200), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %add.84 = s32[299]{0} add(s32[299]{0} %convert.18, s32[299]{0} %broadcast.199), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %select.134 = s32[299]{0} select(pred[299]{0} %compare.127, s32[299]{0} %add.84, s32[299]{0} %convert.18), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.21 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.134), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.13 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %bitcast.22, s32[299,1]{1,0} %bitcast.21), dimensions={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/concatenate" source_file="" source_line=60} %gather.19 = f32[299]{0} gather(f32[300,100000]{1,0} %param_1.21, s32[299,2]{1,0} %concatenate.13), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/gather" source_file="" source_line=60} %select.133 = f32[299]{0} select(pred[299]{0} %and.32, f32[299]{0} %gather.19, f32[299]{0} %broadcast.206), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.27 = f32[299]{0} multiply(f32[299]{0} %subtract.15, f32[299]{0} %select.133), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.83 = s32[299]{0} add(s32[299]{0} %convert.18, s32[299]{0} %broadcast.205), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %compare.125 = pred[299]{0} compare(s32[299]{0} %add.83, s32[299]{0} %broadcast.200), direction=GE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/ge" source_file="" source_line=60} %compare.124 = pred[299]{0} compare(s32[299]{0} %add.83, s32[299]{0} %broadcast.199), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.31 = pred[299]{0} and(pred[299]{0} %compare.125, pred[299]{0} %compare.124), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %and.30 = pred[299]{0} and(pred[299]{0} %and.34, pred[299]{0} %and.31), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %compare.123 = pred[299]{0} compare(s32[299]{0} %add.83, s32[299]{0} %broadcast.200), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.206 = s32[] constant(100001) %broadcast.195 = s32[299]{0} broadcast(s32[] %constant.206), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %add.82 = s32[299]{0} add(s32[299]{0} %convert.18, s32[299]{0} %broadcast.195), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %select.132 = s32[299]{0} select(pred[299]{0} %compare.123, s32[299]{0} %add.82, s32[299]{0} %add.83), 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.20 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.132), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %concatenate.12 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %bitcast.22, s32[299,1]{1,0} %bitcast.20), dimensions={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/concatenate" source_file="" source_line=60} %gather.18 = f32[299]{0} gather(f32[300,100000]{1,0} %param_1.21, s32[299,2]{1,0} %concatenate.12), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/gather" source_file="" source_line=60} %select.131 = f32[299]{0} select(pred[299]{0} %and.30, f32[299]{0} %gather.18, f32[299]{0} %broadcast.206), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.25 = f32[299]{0} multiply(f32[299]{0} %subtract.16, f32[299]{0} %select.131), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.81 = f32[299]{0} add(f32[299]{0} %multiply.27, f32[299]{0} %multiply.25), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.24 = f32[299]{0} multiply(f32[299]{0} %broadcast.206, f32[299]{0} %subtract.15), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.80 = s32[299]{0} add(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.205), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" 
source_file="" source_line=60} %compare.122 = pred[299]{0} compare(s32[299]{0} %add.80, s32[299]{0} %broadcast.200), direction=GE, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/ge" source_file="" source_line=60} %compare.121 = pred[299]{0} compare(s32[299]{0} %add.80, s32[299]{0} %broadcast.196), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %and.29 = pred[299]{0} and(pred[299]{0} %compare.122, pred[299]{0} %compare.121), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %and.28 = pred[299]{0} and(pred[299]{0} %and.29, pred[299]{0} %and.33), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %compare.120 = pred[299]{0} compare(s32[299]{0} %add.80, s32[299]{0} %broadcast.200), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/lt" source_file="" source_line=60} %constant.205 = s32[] constant(301) %broadcast.194 = s32[299]{0} broadcast(s32[] %constant.205), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %add.79 = s32[299]{0} add(s32[299]{0} %param_2.24, s32[299]{0} %broadcast.194), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %select.130 = s32[299]{0} select(pred[299]{0} %compare.120, s32[299]{0} %add.79, s32[299]{0} %add.80), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} %bitcast.19 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.130), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/select_n" source_file="" source_line=60} 
%concatenate.11 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %bitcast.19, s32[299,1]{1,0} %bitcast.21), dimensions={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/concatenate" source_file="" source_line=60} %gather.17 = f32[299]{0} gather(f32[300,100000]{1,0} %param_1.21, s32[299,2]{1,0} %concatenate.11), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/gather" source_file="" source_line=60} %select.129 = f32[299]{0} select(pred[299]{0} %and.28, f32[299]{0} %gather.17, f32[299]{0} %broadcast.206), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.23 = f32[299]{0} multiply(f32[299]{0} %multiply.24, f32[299]{0} %select.129), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.78 = f32[299]{0} add(f32[299]{0} %add.81, f32[299]{0} %multiply.23), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.22 = f32[299]{0} multiply(f32[299]{0} %broadcast.206, f32[299]{0} %subtract.16), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %and.27 = pred[299]{0} and(pred[299]{0} %and.29, pred[299]{0} %and.31), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/and" source_file="" source_line=60} %concatenate.10 = s32[299,2]{1,0} concatenate(s32[299,1]{1,0} %bitcast.19, s32[299,1]{1,0} %bitcast.20), dimensions={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/concatenate" source_file="" source_line=60} %gather.16 = f32[299]{0} gather(f32[300,100000]{1,0} %param_1.21, 
s32[299,2]{1,0} %concatenate.10), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/gather" source_file="" source_line=60} %select.128 = f32[299]{0} select(pred[299]{0} %and.27, f32[299]{0} %gather.16, f32[299]{0} %broadcast.206), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/jit(_where)/select_n" source_file="" source_line=60} %multiply.20 = f32[299]{0} multiply(f32[299]{0} %multiply.22, f32[299]{0} %select.128), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/mul" source_file="" source_line=60} %add.77 = f32[299]{0} add(f32[299]{0} %add.78, f32[299]{0} %multiply.20), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_map_coordinates)/add" source_file="" source_line=60} %multiply.19 = f32[299]{0} multiply(f32[299]{0} %convert.20, f32[299]{0} %add.77), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/mul" source_file="" source_line=66} %add.76 = f32[299]{0} add(f32[299]{0} %select.141, f32[299]{0} %multiply.19), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/add" source_file="" source_line=65} %slice.12 = f32[1]{0} slice(f32[300]{0} %param_0.9), slice={[299:300]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/dynamic_slice" source_file="" source_line=67} %negate.36 = f32[1]{0} negate(f32[1]{0} %slice.12), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/neg" source_file="" source_line=67} %exponential.2 = f32[1]{0} exponential(f32[1]{0} %negate.36), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/exp" source_file="" source_line=67} ROOT %concatenate.9 = f32[300]{0} concatenate(f32[299]{0} %add.76, f32[1]{0} %exponential.2), dimensions={0}, 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(append)/concatenate" source_file="" source_line=67} } %fused_computation.5 (param_0.12: s32[299]) -> f32[299] { %iota.6 = s32[100000]{0} iota(), iota_dimension=0, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/iota" source_file="" source_line=58} %convert.22 = f32[100000]{0} convert(s32[100000]{0} %iota.6), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/convert_element_type" source_file="" source_line=58} %constant.217 = s32[] constant(1) %broadcast.212 = s32[299]{0} broadcast(s32[] %constant.217), dimensions={} %param_0.12 = s32[299]{0} parameter(0) %constant.215 = s32[] constant(99999) %broadcast.210 = s32[299]{0} broadcast(s32[] %constant.215), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.2 = s32[299]{0} clamp(s32[299]{0} %broadcast.212, s32[299]{0} %param_0.12, s32[299]{0} %broadcast.210), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.214 = s32[] constant(-1) %broadcast.209 = s32[299]{0} broadcast(s32[] %constant.214), dimensions={} %add.91 = s32[299]{0} add(s32[299]{0} %clamp.2, s32[299]{0} %broadcast.209), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/sub" source_file="" source_line=58} %constant.216 = s32[] constant(0) %broadcast.211 = s32[299]{0} broadcast(s32[] %constant.216), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.137 = pred[299]{0} compare(s32[299]{0} %add.91, s32[299]{0} %broadcast.211), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/lt" source_file="" source_line=58} %add.90 = s32[299]{0} add(s32[299]{0} %clamp.2, 
s32[299]{0} %broadcast.210), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/add" source_file="" source_line=58} %select.142 = s32[299]{0} select(pred[299]{0} %compare.137, s32[299]{0} %add.90, s32[299]{0} %add.91), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.28 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.142), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.22 = f32[299]{0} gather(f32[100000]{0} %convert.22, s32[299,1]{1,0} %bitcast.28), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.6 (param_0.13: f32[100000], param_1.40: s32[299]) -> f32[299] { %param_0.13 = f32[100000]{0} parameter(0) %constant.221 = s32[] constant(1) %broadcast.216 = s32[299]{0} broadcast(s32[] %constant.221), dimensions={} %param_1.40 = s32[299]{0} parameter(1) %constant.219 = s32[] constant(99999) %broadcast.214 = s32[299]{0} broadcast(s32[] %constant.219), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %clamp.3 = s32[299]{0} clamp(s32[299]{0} %broadcast.216, s32[299]{0} %param_1.40, s32[299]{0} %broadcast.214), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(clip)/min" source_file="" source_line=58} %constant.218 = s32[] constant(-1) %broadcast.213 = s32[299]{0} broadcast(s32[] %constant.218), dimensions={} %add.93 = s32[299]{0} add(s32[299]{0} %clamp.3, s32[299]{0} %broadcast.213), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/sub" source_file="" source_line=58} %constant.220 = s32[] constant(0) %broadcast.215 = s32[299]{0} 
broadcast(s32[] %constant.220), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.138 = pred[299]{0} compare(s32[299]{0} %add.93, s32[299]{0} %broadcast.215), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/lt" source_file="" source_line=58} %add.92 = s32[299]{0} add(s32[299]{0} %clamp.3, s32[299]{0} %broadcast.214), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/add" source_file="" source_line=58} %select.143 = s32[299]{0} select(pred[299]{0} %compare.138, s32[299]{0} %add.92, s32[299]{0} %add.93), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} %bitcast.29 = s32[299,1]{1,0} bitcast(s32[299]{0} %select.143), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/select_n" source_file="" source_line=58} ROOT %gather.23 = f32[299]{0} gather(f32[100000]{0} %param_0.13, s32[299,1]{1,0} %bitcast.29), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} } %fused_computation.7 (param_0.20: f32[299], param_1.50: f32[]) -> s32[299] { %param_1.50 = f32[] parameter(1) %broadcast.221 = f32[299]{0} broadcast(f32[] %param_1.50), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/sub" source_file="" source_line=56} %param_0.20 = f32[299]{0} parameter(0) %subtract.23 = f32[299]{0} subtract(f32[299]{0} %broadcast.221, f32[299]{0} %param_0.20), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/sub" source_file="" source_line=56} %compare.141 = pred[299]{0} compare(f32[299]{0} %subtract.23, f32[299]{0} %subtract.23), direction=NE, 
metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/ne" source_file="" source_line=58} %constant.223 = f32[] constant(nan) %broadcast.218 = f32[299]{0} broadcast(f32[] %constant.223), dimensions={} %constant.225 = f32[] constant(0) %broadcast.220 = f32[299]{0} broadcast(f32[] %constant.225), dimensions={} %compare.140 = pred[299]{0} compare(f32[299]{0} %subtract.23, f32[299]{0} %broadcast.220), direction=EQ, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/eq" source_file="" source_line=58} %select.146 = f32[299]{0} select(pred[299]{0} %compare.140, f32[299]{0} %broadcast.220, f32[299]{0} %subtract.23), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %select.145 = f32[299]{0} select(pred[299]{0} %compare.141, f32[299]{0} %broadcast.218, f32[299]{0} %select.146), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/vmap(while)/body/select_n" source_file="" source_line=58} %bitcast-convert.11 = s32[299]{0} bitcast-convert(f32[299]{0} %select.145) %constant.224 = s32[] constant(0) %broadcast.219 = s32[299]{0} broadcast(s32[] %constant.224), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %compare.139 = pred[299]{0} compare(s32[299]{0} %bitcast-convert.11, s32[299]{0} %broadcast.219), direction=LT %constant.222 = s32[] constant(2147483647) %broadcast.217 = s32[299]{0} broadcast(s32[] %constant.222), dimensions={} %xor.11 = s32[299]{0} xor(s32[299]{0} %broadcast.217, s32[299]{0} %bitcast-convert.11) ROOT %select.144 = s32[299]{0} select(pred[299]{0} %compare.139, s32[299]{0} %xor.11, s32[299]{0} %bitcast-convert.11) } %fused_computation.8 (param_0.23: f32[300]) -> 
f32[299] { %param_0.23 = f32[300]{0} parameter(0) %slice.19 = f32[299]{0} slice(f32[300]{0} %param_0.23), slice={[1:300]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/slice" source_file="" source_line=56} %slice.18 = f32[299]{0} slice(f32[300]{0} %param_0.23), slice={[0:299]}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/slice" source_file="" source_line=56} %divide.5 = f32[299]{0} divide(f32[299]{0} %slice.19, f32[299]{0} %slice.18), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/div" source_file="" source_line=56} ROOT %log.0 = f32[299]{0} log(f32[299]{0} %divide.5), metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/log" source_file="" source_line=56} } %region_2.380 (arg_tuple.381: (f32[300], f32[], f32[100000], f32[300,100000], s32[299])) -> (f32[300]) { %arg_tuple.381 = (f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) parameter(0) %get-tuple-element.382 = f32[300]{0} get-tuple-element((f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %arg_tuple.381), index=0 %fusion.8 = f32[299]{0} fusion(f32[300]{0} %get-tuple-element.382), kind=kLoop, calls=%fused_computation.8, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/log" source_file="" source_line=56} %get-tuple-element.383 = f32[] get-tuple-element((f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %arg_tuple.381), index=1 %fusion.7 = s32[299]{0} fusion(f32[299]{0} %fusion.8, f32[] %get-tuple-element.383), kind=kLoop, calls=%fused_computation.7 %get-tuple-element.117 = f32[100000]{0} get-tuple-element((f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %arg_tuple.381), index=2 %constant.25 = s32[] constant(0) %broadcast.44 = s32[299]{0} broadcast(s32[] %constant.25), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" 
source_line=58} %constant.27 = s32[] constant(100000) %broadcast.45 = s32[299]{0} broadcast(s32[] %constant.27), dimensions={}, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/broadcast_in_dim" source_file="" source_line=58} %copy = s32[] copy(s32[] %constant.25) %tuple.25 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) tuple(s32[] %copy, s32[299]{0} %broadcast.44, s32[299]{0} %broadcast.45, f32[100000]{0} %get-tuple-element.117, s32[299]{0} %fusion.7) %while.6 = (s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) while((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %tuple.25), condition=%wide.region_4.118.clone.clone, body=%wide.region_3.105.clone.clone, backend_config={"known_trip_count":{"n":"17"}} %get-tuple-element.1 = s32[299]{0} get-tuple-element((s32[], s32[299]{0}, s32[299]{0}, f32[100000]{0}, s32[299]{0}) %while.6), index=2, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/jit(searchsorted)/while" source_file="" source_line=58} %fusion.6 = f32[299]{0} fusion(f32[100000]{0} %get-tuple-element.117, s32[299]{0} %get-tuple-element.1), kind=kLoop, calls=%fused_computation.6, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} %fusion.5 = f32[299]{0} fusion(s32[299]{0} %get-tuple-element.1), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(_interp)/gather" source_file="" source_line=58} %get-tuple-element.385 = f32[300,100000]{1,0} get-tuple-element((f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %arg_tuple.381), index=3 %get-tuple-element.386 = s32[299]{0} get-tuple-element((f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %arg_tuple.381), index=4 %fusion.4 = f32[300]{0} fusion(f32[300]{0} %get-tuple-element.382, f32[300,100000]{1,0} %get-tuple-element.385, 
s32[299]{0} %get-tuple-element.386, f32[299]{0} %fusion.5, f32[299]{0} %fusion.6, /*index=5*/f32[100000]{0} %get-tuple-element.117, s32[299]{0} %get-tuple-element.1, s32[] %constant.25, f32[299]{0} %fusion.8, f32[] %get-tuple-element.383), kind=kLoop, calls=%fused_computation.4, metadata={op_name="jit(main2)/jit(main)/while/body/cond/branch_1_fun/jit(append)/concatenate" source_file="" source_line=67} ROOT %tuple.29 = (f32[300]{0}) tuple(f32[300]{0} %fusion.4) } %fused_computation.9 (param_0.24: f32[300,100000], param_1.55: s32[], param_2.50: f32[300], param_3.59: pred[]) -> f32[300,100000] { %param_0.24 = f32[300,100000]{1,0} parameter(0) %param_3.59 = pred[] parameter(3) %broadcast.222 = pred[300,1]{1,0} broadcast(pred[] %param_3.59), dimensions={} %param_2.50 = f32[300]{0} parameter(2) %bitcast.30 = f32[300,1]{1,0} bitcast(f32[300]{0} %param_2.50), metadata={op_name="jit(main2)/jit(main)/while/body/cond" source_file="" source_line=40} %constant.226 = s32[] constant(0) %param_1.55 = s32[] parameter(1) %dynamic-slice.9 = f32[300,1]{1,0} dynamic-slice(f32[300,100000]{1,0} %param_0.24, s32[] %constant.226, s32[] %param_1.55), dynamic_slice_sizes={300,1} %select.147 = f32[300,1]{1,0} select(pred[300,1]{1,0} %broadcast.222, f32[300,1]{1,0} %bitcast.30, f32[300,1]{1,0} %dynamic-slice.9) ROOT %dynamic-update-slice.4 = f32[300,100000]{1,0} dynamic-update-slice(f32[300,100000]{1,0} %param_0.24, f32[300,1]{1,0} %select.147, s32[] %constant.226, s32[] %param_1.55), metadata={op_name="jit(main2)/jit(main)/while/body/scatter" source_file="" source_line=73} } %fused_computation.10 (param_0.26: f32[100000], param_1.56: s32[]) -> f32[] { %param_0.26 = f32[100000]{0} parameter(0) %param_1.56 = s32[] parameter(1) %dynamic-slice.10 = f32[1]{0} dynamic-slice(f32[100000]{0} %param_0.26, s32[] %param_1.56), dynamic_slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=39} ROOT %bitcast.31 = f32[] bitcast(f32[1]{0} 
%dynamic-slice.10), metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=39} } %fused_computation.11 (param_0.30: s32[100000], param_1.60: s32[]) -> s32[] { %param_0.30 = s32[100000]{0} parameter(0) %param_1.60 = s32[] parameter(1) %constant.229 = s32[] constant(0) %compare.143 = pred[] compare(s32[] %param_1.60, s32[] %constant.229), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/lt" source_file="" source_line=46} %constant.228 = s32[] constant(100000) %add.94 = s32[] add(s32[] %param_1.60, s32[] %constant.228), metadata={op_name="jit(main2)/jit(main)/while/body/add" source_file="" source_line=46} %select.148 = s32[] select(pred[] %compare.143, s32[] %add.94, s32[] %param_1.60), metadata={op_name="jit(main2)/jit(main)/while/body/select_n" source_file="" source_line=46} %dynamic-slice.11 = s32[1]{0} dynamic-slice(s32[100000]{0} %param_0.30, s32[] %select.148), dynamic_slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=46} %bitcast.32 = s32[] bitcast(s32[1]{0} %dynamic-slice.11), metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=46} %constant.227 = s32[] constant(-1) %compare.142 = pred[] compare(s32[] %bitcast.32, s32[] %constant.227), direction=GT, metadata={op_name="jit(main2)/jit(main)/while/body/gt" source_file="" source_line=40} ROOT %convert.23 = s32[] convert(pred[] %compare.142), metadata={op_name="jit(main2)/jit(main)/while/body/convert_element_type" source_file="" source_line=40} } %and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] { %lhs = pred[] parameter(0) %rhs = pred[] parameter(1) ROOT %and.10 = pred[] and(pred[] %lhs, pred[] %rhs) } %fused_computation.12 (param_0.34: s32[2], param_1.65: s32[2]) -> pred[] { %constant.231 = s32[] constant(0) %broadcast.223 = s32[2]{0} broadcast(s32[] %constant.231), dimensions={} %param_1.65 = s32[2]{0} parameter(1) %compare.145 = pred[2]{0} 
compare(s32[2]{0} %broadcast.223, s32[2]{0} %param_1.65), direction=LE %param_0.34 = s32[2]{0} parameter(0) %compare.144 = pred[2]{0} compare(s32[2]{0} %param_0.34, s32[2]{0} %param_1.65), direction=GE %and.35 = pred[2]{0} and(pred[2]{0} %compare.145, pred[2]{0} %compare.144) %constant.230 = pred[] constant(true) ROOT %reduce.3 = pred[] reduce(pred[2]{0} %and.35, pred[] %constant.230), dimensions={0}, to_apply=%and.reduce_sub_computation } %fused_computation.13 (param_0.36: s32[]) -> s32[2] { %constant.232 = s32[1]{0} constant({0}) %param_0.36 = s32[] parameter(0) %bitcast.33 = s32[1]{0} bitcast(s32[] %param_0.36), metadata={op_name="jit(main2)/jit(main)/while/body/select_n" source_file="" source_line=73} ROOT %concatenate.14 = s32[2]{0} concatenate(s32[1]{0} %constant.232, s32[1]{0} %bitcast.33), dimensions={0} } %fused_computation.14 (param_0.41: s32[100000], param_1.73: s32[]) -> s32[] { %param_0.41 = s32[100000]{0} parameter(0) %param_1.73 = s32[] parameter(1) %constant.234 = s32[] constant(0) %compare.147 = pred[] compare(s32[] %param_1.73, s32[] %constant.234), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/body/lt" source_file="" source_line=46} %constant.233 = s32[] constant(100000) %add.96 = s32[] add(s32[] %param_1.73, s32[] %constant.233), metadata={op_name="jit(main2)/jit(main)/while/body/add" source_file="" source_line=46} %select.150 = s32[] select(pred[] %compare.147, s32[] %add.96, s32[] %param_1.73), metadata={op_name="jit(main2)/jit(main)/while/body/select_n" source_file="" source_line=46} %dynamic-slice.12 = s32[1]{0} dynamic-slice(s32[100000]{0} %param_0.41, s32[] %select.150), dynamic_slice_sizes={1}, metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=46} %bitcast.34 = s32[] bitcast(s32[1]{0} %dynamic-slice.12), metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=46} %compare.146 = pred[] compare(s32[] %bitcast.34, s32[] %constant.234), direction=LT, 
metadata={op_name="jit(main2)/jit(main)/while/body/lt" source_file="" source_line=73} %add.95 = s32[] add(s32[] %bitcast.34, s32[] %constant.233), metadata={op_name="jit(main2)/jit(main)/while/body/add" source_file="" source_line=73} ROOT %select.149 = s32[] select(pred[] %compare.146, s32[] %add.95, s32[] %bitcast.34), metadata={op_name="jit(main2)/jit(main)/while/body/select_n" source_file="" source_line=73} } %wide.region_0.443 (wide.arg_tuple.444: (s32[], f32[300,100000], s32[100000], f32[100000], s32[2])) -> (s32[], f32[300,100000], s32[100000], f32[100000], s32[2]) { %wide.arg_tuple.444 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) parameter(0) %get-tuple-element.144 = s32[] get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.444), index=0 %copy.14 = s32[] copy(s32[] %get-tuple-element.144) %get-tuple-element.151 = s32[100000]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.444), index=2 %fusion.14 = s32[] fusion(s32[100000]{0} %get-tuple-element.151, s32[] %copy.14), kind=kLoop, calls=%fused_computation.14, metadata={op_name="jit(main2)/jit(main)/while/body/select_n" source_file="" source_line=73} %get-tuple-element.152 = f32[100000]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.444), index=3 %fusion.10 = f32[] fusion(f32[100000]{0} %get-tuple-element.152, s32[] %fusion.14), kind=kLoop, calls=%fused_computation.10, metadata={op_name="jit(main2)/jit(main)/while/body/dynamic_slice" source_file="" source_line=39} %get-tuple-element.145 = f32[300,100000]{1,0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.444), index=1 %constant.158 = f32[300]{0} constant({...}) %constant.159 = s32[299]{0} constant({...}) %tuple.18 = (f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) 
tuple(f32[300]{0} %constant.158, f32[] %fusion.10, f32[100000]{0} %get-tuple-element.152, f32[300,100000]{1,0} %get-tuple-element.145, s32[299]{0} %constant.159), metadata={op_name="jit(main2)/jit(main)/while/body/cond" source_file="" source_line=40} %fusion.11 = s32[] fusion(s32[100000]{0} %get-tuple-element.151, s32[] %copy.14), kind=kLoop, calls=%fused_computation.11, metadata={op_name="jit(main2)/jit(main)/while/body/convert_element_type" source_file="" source_line=40} %conditional.0.clone.1 = (f32[300]{0}) conditional(s32[] %fusion.11, f32[300]{0} %constant.158, (f32[300]{0}, f32[], f32[100000]{0}, f32[300,100000]{1,0}, s32[299]{0}) %tuple.18), branch_computations={%region_1.33, %region_2.380}, metadata={op_name="jit(main2)/jit(main)/while/body/cond" source_file="" source_line=40} %get-tuple-element.79 = f32[300]{0} get-tuple-element((f32[300]{0}) %conditional.0.clone.1), index=0, metadata={op_name="jit(main2)/jit(main)/while/body/cond" source_file="" source_line=40} %fusion.13 = s32[2]{0} fusion(s32[] %fusion.14), kind=kLoop, calls=%fused_computation.13 %get-tuple-element.153 = s32[2]{0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.444), index=4 %fusion.12 = pred[] fusion(s32[2]{0} %get-tuple-element.153, s32[2]{0} %fusion.13), kind=kLoop, calls=%fused_computation.12 %fusion.9 = f32[300,100000]{1,0} fusion(f32[300,100000]{1,0} %get-tuple-element.145, s32[] %fusion.14, f32[300]{0} %get-tuple-element.79, pred[] %fusion.12), kind=kLoop, calls=%fused_computation.9, metadata={op_name="jit(main2)/jit(main)/while/body/scatter" source_file="" source_line=73} %constant.148 = s32[] constant(1) %add.61 = s32[] add(s32[] %copy.14, s32[] %constant.148), metadata={op_name="jit(main2)/jit(main)/while/body/add" source_file="" source_line=46} ROOT %tuple.33 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) tuple(s32[] %add.61, f32[300,100000]{1,0} %fusion.9, s32[100000]{0} 
%get-tuple-element.151, f32[100000]{0} %get-tuple-element.152, s32[2]{0} %get-tuple-element.153) } %wide.region_6.462 (wide.arg_tuple.463: (s32[], f32[300,100000], s32[100000], f32[100000], s32[2])) -> pred[] { %constant.147 = s32[] constant(100000) %wide.arg_tuple.463 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) parameter(0) %get-tuple-element.70 = s32[] get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %wide.arg_tuple.463), index=0 ROOT %compare.79 = pred[] compare(s32[] %get-tuple-element.70, s32[] %constant.147), direction=LT, metadata={op_name="jit(main2)/jit(main)/while/cond/lt" source_file="" source_line=46} } %fused_computation.15 () -> s32[2] { %iota.7 = s32[2]{0} iota(), iota_dimension=0 %constant.235 = s32[] constant(99999) %broadcast.224 = s32[2]{0} broadcast(s32[] %constant.235), dimensions={} ROOT %multiply.29 = s32[2]{0} multiply(s32[2]{0} %iota.7, s32[2]{0} %broadcast.224) } %fused_computation.16 () -> f32[100000] { %constant.239 = f32[] constant(0) %broadcast.227 = f32[99999]{0} broadcast(f32[] %constant.239), dimensions={}, metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/mul" source_file="" source_line=12} %constant.238 = f32[] constant(1) %broadcast.226 = f32[99999]{0} broadcast(f32[] %constant.238), dimensions={} %iota.8 = f32[99999]{0} iota(), iota_dimension=0, metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/iota" source_file="" source_line=12} %constant.237 = f32[] constant(1.00001e-05) %broadcast.225 = f32[99999]{0} broadcast(f32[] %constant.237), dimensions={} %multiply.31 = f32[99999]{0} multiply(f32[99999]{0} %iota.8, f32[99999]{0} %broadcast.225), metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/div" source_file="" source_line=12} %subtract.24 = f32[99999]{0} subtract(f32[99999]{0} %broadcast.226, f32[99999]{0} %multiply.31), metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/sub" source_file="" source_line=12} %multiply.30 = f32[99999]{0} 
multiply(f32[99999]{0} %broadcast.227, f32[99999]{0} %subtract.24), metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/mul" source_file="" source_line=12} %add.97 = f32[99999]{0} add(f32[99999]{0} %multiply.30, f32[99999]{0} %multiply.31), metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/add" source_file="" source_line=12} %constant.236 = f32[1]{0} constant({1}), metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/broadcast_in_dim" source_file="" source_line=12} ROOT %concatenate.15 = f32[100000]{0} concatenate(f32[99999]{0} %add.97, f32[1]{0} %constant.236), dimensions={0}, metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/concatenate" source_file="" source_line=12} } %parallel_broadcast.2 (p: f32[]) -> f32[300,100000] { %p = f32[] parameter(0) ROOT %broadcast.2.clone = f32[300,100000]{1,0} broadcast(f32[] %p), dimensions={}, backend_config={"outer_dimension_partitions":["2"]} } ENTRY %main.473 () -> f32[300,100000] { %constant.1 = f32[] constant(0) %call.5 = f32[300,100000]{1,0} call(f32[] %constant.1), to_apply=%parallel_broadcast.2 %fusion.16 = f32[100000]{0} fusion(), kind=kLoop, calls=%fused_computation.16, metadata={op_name="jit(main2)/jit(main)/jit(_linspace)/concatenate" source_file="" source_line=12} %iota.31 = s32[100000]{0} iota(), iota_dimension=0, metadata={op_name="jit(main2)/jit(main)/iota" source_file="" source_line=46} %fusion.15 = s32[2]{0} fusion(), kind=kLoop, calls=%fused_computation.15 %constant.3 = s32[] constant(0) %copy.18 = s32[] copy(s32[] %constant.3) %tuple.30 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) tuple(s32[] %copy.18, f32[300,100000]{1,0} %call.5, s32[100000]{0} %iota.31, f32[100000]{0} %fusion.16, s32[2]{0} %fusion.15) %while.5 = (s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) while((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %tuple.30), condition=%wide.region_6.462, body=%wide.region_0.443, 
backend_config={"known_trip_count":{"n":"100000"}}
  ROOT %get-tuple-element.472 = f32[300,100000]{1,0} get-tuple-element((s32[], f32[300,100000]{1,0}, s32[100000]{0}, f32[100000]{0}, s32[2]{0}) %while.5), index=1, metadata={op_name="jit(main2)/jit(main)/while" source_file="" source_line=46}
}
```

Granted, it's a lot of HLO, so I'm not sure which fusion decisions lead to the different runtime characteristics. It may be worth reporting this upstream to the XLA project, since it looks like the compiler is missing an available optimization in the simpler Case 1.
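Instead of reading the full dump, one way to see where the two versions diverge is to compile each one and summarize the fusion instructions in the optimized HLO. The sketch below is a scaled-down, hypothetical stand-in for the reported loop. `toy_interp` replaces `interp_A_from_B` (its body is not shown in this part of the thread), and the sizes are reduced from `Nx = 300`, `Ny = 100000`. Counting substrings such as `kind=kLoop` in the text returned by `jax.jit(...).lower().compile().as_text()` is only a crude heuristic.

```python
import jax
import jax.numpy as jnp
from jax import lax

NX, NY = 30, 1000  # scaled down from the issue's Nx=300, Ny=100000 (assumption)
X_AXIS = jnp.linspace(5.0, 12.75, NX)


def toy_interp(args):
    # Hypothetical stand-in for interp_A_from_B; the real body is not shown here.
    y, B = args
    return jnp.exp(-X_AXIS * y) + B[:, 0]


def toy_false(args):
    # False branch for lax.cond; must match toy_interp's output shape and dtype.
    return jnp.zeros(NX, dtype=jnp.float32)


def build(use_cond):
    @jax.jit
    def main():
        y_axis = jnp.linspace(0.0, 1.0, NY)
        B0 = jnp.zeros((NX, NY), dtype=jnp.float32)

        def body(B, i):
            args = (y_axis[i], B)
            if use_cond:
                A = lax.cond(i > -1, toy_interp, toy_false, args)  # Case 2
            else:
                A = toy_interp(args)  # Case 1
            return B.at[:, i].set(A), None

        B, _ = lax.scan(body, B0, jnp.arange(NY))
        return B

    return main


def optimized_hlo(fn):
    # Text of the HLO module after XLA's optimization passes (including fusion).
    return fn.lower().compile().as_text()


def fusion_summary(hlo_text):
    # Crude heuristic: count fusion instructions and the fusion kinds XLA chose.
    return {
        "fusions": hlo_text.count(" fusion("),
        "kLoop": hlo_text.count("kind=kLoop"),
        "kInput": hlo_text.count("kind=kInput"),
        "kOutput": hlo_text.count("kind=kOutput"),
    }


if __name__ == "__main__":
    for name, use_cond in (("case1", False), ("case2", True)):
        print(name, fusion_summary(optimized_hlo(build(use_cond))))
```

Diffing the two `optimized_hlo(...)` strings is usually more informative than the counts alone. Another option is to set the environment variable `XLA_FLAGS=--xla_dump_to=<dir>` before importing JAX, which makes XLA write its HLO modules to that directory. Those files could then be attached to an upstream XLA report.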

cgiovanetti commented 2 weeks ago

Also, we tried adding the line `A = lax.optimization_barrier(A)` to Case 1, and it didn't resolve the issue. That may be interesting: it suggests the speedup is not simply due to `lax.cond` acting as an optimization barrier.