openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Large GPU memory allocation for newer versions #17124

Open: pwithams opened this issue 1 month ago

pwithams commented 1 month ago

This is a more XLA-specific version of https://github.com/google/jax/issues/23548, encountered using an NVIDIA GPU.

Basically, with dls=jnp.ones(shape=(1590, 3)) the program runs successfully and pprof reports ~500kB of memory usage, but increasing it to dls=jnp.ones(shape=(1600, 3)) makes it fail while trying to allocate ~5GB. So there seems to be a discrepancy between the GPU memory that is requested and the memory that is actually used.
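For scale, here is the size of the five-dimensional f32 intermediate that both dumps below compute (written f32[1,3540,71,75,71] in the 0.4.14 fusion and f32[3540,71,1,71,75] in the 0.4.16 one), assuming a dense buffer at 4 bytes per element with no padding. It lands in the same range as the failing ~5GB allocation, though this estimate alone does not show that this particular buffer is the one being allocated.

$$
3540 \times 71 \times 75 \times 71 \times 4\ \text{B} = 5{,}353{,}542{,}000\ \text{B} \approx 5.35\ \text{GB} \approx 4.99\ \text{GiB}
$$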

The JAX script described in that ticket works on JAX v0.4.14 but not on v0.4.16 and above. The Python xla_client version is 174 for v0.4.14 and 194 for v0.4.16, so I think the change in behaviour was introduced somewhere between commits 5ca49a9 and 326f72f. I see there were some changes/additions to the cuDNN fusion logic between those commits, and the last line of module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt now seems to reference a "fusion" rather than a "reduce" as before; not sure if that is relevant.
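One way to make the fusion-vs-reduce difference concrete is to compare the operands of `fused_computation` in the two dumps below. The sketch scans HLO text for `parameter(N)` declarations and reports each one's dense size, assuming no padding; the helper `parameter_bytes` is hypothetical, not an XLA API. With the parameter lines copied from the dumps, the 0.4.14 fusion reads about 9 MB of inputs and builds the large 5-D tensor internally, while the 0.4.16 fusion receives two f32[3540,71,1,71,75] operands of about 5.35 GB each, which suggests those tensors are now materialized in device memory before the fusion runs.

```python
import math
import re

# Element widths in bytes for the HLO element types used here (dense, unpadded).
BYTES_PER_ELEMENT = {"pred": 1, "s32": 4, "f16": 2, "bf16": 2, "f32": 4}

# Matches declarations such as "param_1.26 = f32[3540,71,1,71,75]{4,3,2,1,0} parameter(1)".
PARAM_RE = re.compile(r"^\s*(\S+) = ([a-z0-9]+)\[([0-9,]*)\]\S* parameter\(\d+\)")


def parameter_bytes(hlo_text):
    """Return {parameter name: dense size in bytes} for each parameter line in hlo_text."""
    sizes = {}
    for line in hlo_text.splitlines():
        m = PARAM_RE.match(line)
        if m:
            name, dtype, dims = m.groups()
            extents = [int(d) for d in dims.split(",") if d]  # "f32[]" -> scalar
            sizes[name] = BYTES_PER_ELEMENT[dtype] * math.prod(extents)
    return sizes


# Parameter lines of `fused_computation`, copied from the two dumps.
OLD = """
  param_1.20 = f32[71,75,71,3]{3,2,1,0} parameter(1)
  param_0.14 = f32[378075,3]{1,0} parameter(0)
"""  # jax 0.4.14
NEW = """
  param_1.26 = f32[3540,71,1,71,75]{4,3,2,1,0} parameter(1)
  param_0.15 = f32[3540,71,1,71,75]{4,3,2,1,0} parameter(0)
"""  # jax 0.4.16

if __name__ == "__main__":
    print(sum(parameter_bytes(OLD).values()), sum(parameter_bytes(NEW).values()))
    # 9073800 10707084000
```

The same scan can be pointed at a whole dump file, but it will then also pick up parameters of the reducer (`region_2.133`) and of the other fused computations, so it is best applied to one computation at a time.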

I've provided the contents of the module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt files below. Let me know if there's any other information I can provide.
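For anyone reproducing these files: they come from XLA's dump facility, which is enabled through the XLA_FLAGS environment variable with `--xla_dump_to=<dir>`. The sketch below appends that flag while keeping any flags already set; the helper name `enable_xla_dump` and the `/tmp/xla_dump` path are hypothetical. It assumes it runs before `jax` is imported, since XLA reads XLA_FLAGS when the backend is initialized.

```python
import os


def enable_xla_dump(dump_dir):
    """Append --xla_dump_to=<dump_dir> to XLA_FLAGS, keeping any existing flags.

    Call this before `import jax`; flags set after the backend starts are ignored.
    """
    flag = f"--xla_dump_to={dump_dir}"
    flags = os.environ.get("XLA_FLAGS", "").split()
    if flag not in flags:  # avoid duplicating the flag on repeated calls
        flags.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print(enable_xla_dump("/tmp/xla_dump"))
    # import jax  # only after the flag is set, then run the script as usual
```

After the script runs with the flag set, the directory should contain files named like module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt (the sm_8.9 part reflects the GPU's compute capability). The same effect can be had from the shell by prefixing the python command with XLA_FLAGS=--xla_dump_to=/tmp/xla_dump.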

xla_dump output

jax==0.4.14, xla_client._version=174, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt

HloModule jit_run, is_scheduled=true, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

region_2.133 {
  Arg_1.135 = f32[] parameter(1)
  Arg_0.134 = f32[] parameter(0)
  ROOT add.14 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
}

fused_computation {
  constant_184 = f32[] constant(1)
  broadcast.210 = f32[71,1]{1,0} broadcast(constant_184), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/code/oom_gh_example.py" source_line=74}
  iota.23 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/code/oom_gh_example.py" source_line=77}
  gather.14 = f32[71,1,1]{2,1,0} gather(broadcast.210, iota.23), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=63}
  constant_115 = f32[] constant(0.00052185118)
  broadcast.139 = f32[71,1,1]{2,1,0} broadcast(constant_115), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  multiply.28 = f32[71,1,1]{2,1,0} multiply(gather.14, broadcast.139), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  bitcast.183 = f32[71]{0} bitcast(multiply.28)
  broadcast.138 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(bitcast.183), dimensions={2}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  param_1.20 = f32[71,75,71,3]{3,2,1,0} parameter(1)
  broadcast.137 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(param_1.20), dimensions={2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  bitcast.182 = f32[3540,71,75,71,3]{4,3,2,1,0} bitcast(broadcast.137)
  constant_123 = f32[] constant(0)
  reduce.10 = f32[3540,71,75,71]{3,2,1,0} reduce(bitcast.182, constant_123), dimensions={4}, to_apply=region_2.133
  bitcast.181 = f32[1,3540,71,75,71]{4,3,2,1,0} bitcast(reduce.10), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
  constant_114 = f32[] constant(-0.5)
  broadcast.136 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(constant_114), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.27 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(bitcast.181, broadcast.136), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  exponential.3 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.27), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=49}
  param_0.14 = f32[378075,3]{1,0} parameter(0)
  constant_121 = f32[] constant(-1)
  broadcast.135 = f32[378075,3]{1,0} broadcast(constant_121), dimensions={}
  add.21 = f32[378075,3]{1,0} add(param_0.14, broadcast.135), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  multiply.26 = f32[378075,3]{1,0} multiply(add.21, add.21), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
  constant_122 = f32[] constant(0.189035922)
  broadcast.134 = f32[378075,3]{1,0} broadcast(constant_122), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.25 = f32[378075,3]{1,0} multiply(multiply.26, broadcast.134), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=50}
  bitcast.180 = f32[71,75,71,3]{3,2,1,0} bitcast(multiply.25)
  broadcast.133 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(bitcast.180), dimensions={2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  bitcast.179 = f32[3540,71,75,71,3]{4,3,2,1,0} bitcast(broadcast.133)
  reduce.9 = f32[3540,71,75,71]{3,2,1,0} reduce(bitcast.179, constant_123), dimensions={4}, to_apply=region_2.133
  bitcast.178 = f32[1,3540,71,75,71]{4,3,2,1,0} bitcast(reduce.9), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=50}
  multiply.24 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(bitcast.178, broadcast.136), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
  exponential.2 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.24), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=50}
  add.20 = f32[1,3540,71,75,71]{4,3,2,1,0} add(exponential.3, exponential.2), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.23 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(broadcast.138, add.20), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  bitcast.177 = f32[3540,5325,71]{2,1,0} bitcast(multiply.23)
  pad.2 = f32[3540,5376,71]{2,1,0} pad(bitcast.177, constant_123), padding=0_0x0_51x0_0
  bitcast.176 = f32[3540,64,84,71]{3,2,1,0} bitcast(pad.2)
  ROOT reduce.8 = f32[3540,64,71]{2,1,0} reduce(bitcast.176, constant_123), dimensions={2}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
} // fused_computation

fused_computation.1 {
  param_0.25 = f32[1,71,71,3,75]{4,3,2,1,0} parameter(0)
  constant_181 = f32[] constant(1)
  broadcast.208 = f32[1,71,75,11,3]{4,3,2,1,0} broadcast(constant_181), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  broadcast.207 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(constant_181), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  concatenate.24 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(broadcast.208, broadcast.207, broadcast.208, broadcast.207), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.212 = f32[71,164,3,1,75]{2,1,4,0,3} bitcast(concatenate.24)
  iota.21 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
  constant_180 = s32[] constant(82)
  broadcast.206 = s32[71]{0} broadcast(constant_180), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  iota.20 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/code/oom_gh_example.py" source_line=77}
  subtract.23 = s32[71]{0} subtract(broadcast.206, iota.20), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  constant_175 = s32[] constant(0)
  broadcast.205 = s32[71]{0} broadcast(constant_175), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  compare.26 = pred[71]{0} compare(subtract.23, broadcast.205), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
  constant_178 = s32[] constant(246)
  broadcast.204 = s32[71]{0} broadcast(constant_178), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  subtract.22 = s32[71]{0} subtract(broadcast.204, iota.20), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  select.17 = s32[71]{0} select(compare.26, subtract.22, subtract.23), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.211 = s32[71,1]{1,0} bitcast(select.17)
  broadcast.203 = s32[71,1]{1,0} broadcast(constant_175), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/code/oom_gh_example.py" source_line=15}
  concatenate.23 = s32[71,3]{1,0} concatenate(iota.21, bitcast.211, broadcast.203), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/code/oom_gh_example.py" source_line=15}
  gather.12 = f32[71,1,82,3,1,75]{5,4,3,2,1,0} gather(bitcast.212, concatenate.23), offset_dims={1,2,3,4,5}, collapsed_slice_dims={}, start_index_map={0,1,2}, index_vector_dim=1, slice_sizes={1,82,3,1,75}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.210 = f32[1,71,82,3,75]{4,3,2,1,0} bitcast(gather.12)
  slice.27 = f32[1,71,71,1,75]{4,3,2,1,0} slice(bitcast.210), slice={[0:1], [0:71], [0:71], [2:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=17}
  constant_177 = f32[] constant(-0.03)
  broadcast.202 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_177), dimensions={}
  add.32 = f32[1,71,71,1,75]{4,3,2,1,0} add(slice.27, broadcast.202), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
  abs.6 = f32[1,71,71,1,75]{4,3,2,1,0} abs(add.32), metadata={op_name="jit(run)/jit(main)/abs" source_file="/code/oom_gh_example.py" source_line=17}
  constant_176 = f32[] constant(0.03)
  broadcast.201 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_176), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
  add.31 = f32[1,71,71,1,75]{4,3,2,1,0} add(abs.6, broadcast.201), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=17}
  constant_174 = s32[] constant(2)
  dynamic-update-slice.8 = f32[1,71,71,3,75]{4,3,2,1,0} dynamic-update-slice(param_0.25, add.31, constant_175, constant_175, constant_175, /*index=5*/constant_174, constant_175), metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/code/oom_gh_example.py" source_line=17}
  transpose.15 = f32[1,71,75,71,3]{4,3,2,1,0} transpose(dynamic-update-slice.8), dimensions={0,1,4,2,3}, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/code/oom_gh_example.py" source_line=17}
  bitcast.184 = f32[378075,3]{1,0} bitcast(transpose.15)
  constant_125_clone_1 = f32[] constant(-1)
  broadcast.141.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} broadcast(constant_125_clone_1), dimensions={}
  add.22.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} add(dynamic-update-slice.8, broadcast.141.clone.1), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  multiply.30.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} multiply(add.22.clone.1, add.22.clone.1), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  constant_126_clone_1 = f32[] constant(0.189035922)
  broadcast.140.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} broadcast(constant_126_clone_1), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.29.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} multiply(multiply.30.clone.1, broadcast.140.clone.1), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
  bitcast.185.clone.1 = f32[71,71,3,75]{3,2,1,0} bitcast(multiply.29.clone.1)
  transpose.16.clone.1 = f32[71,75,71,3]{3,2,1,0} transpose(bitcast.185.clone.1), dimensions={0,3,1,2}
  ROOT tuple.5 = (f32[378075,3]{1,0}, f32[71,75,71,3]{3,2,1,0}) tuple(bitcast.184, transpose.16.clone.1)
} // fused_computation.1

fused_computation.6 {
  constant_138 = f32[] constant(1)
  broadcast.164 = f32[1,71,75,11,3]{4,3,2,1,0} broadcast(constant_138), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  broadcast.163 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(constant_138), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  concatenate.12 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(broadcast.164, broadcast.163, broadcast.164, broadcast.163), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.194 = f32[71,164,3,1,75]{2,1,4,0,3} bitcast(concatenate.12)
  iota.9 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
  constant_137 = s32[] constant(82)
  broadcast.162 = s32[71]{0} broadcast(constant_137), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  iota.8 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/code/oom_gh_example.py" source_line=77}
  subtract.11 = s32[71]{0} subtract(broadcast.162, iota.8), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  constant_136 = s32[] constant(0)
  broadcast.161 = s32[71]{0} broadcast(constant_136), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  compare.20 = pred[71]{0} compare(subtract.11, broadcast.161), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
  constant_135 = s32[] constant(246)
  broadcast.160 = s32[71]{0} broadcast(constant_135), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  subtract.10 = s32[71]{0} subtract(broadcast.160, iota.8), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  select.11 = s32[71]{0} select(compare.20, subtract.10, subtract.11), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.193 = s32[71,1]{1,0} bitcast(select.11)
  broadcast.159 = s32[71,1]{1,0} broadcast(constant_136), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/code/oom_gh_example.py" source_line=15}
  concatenate.11 = s32[71,3]{1,0} concatenate(iota.9, bitcast.193, broadcast.159), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/code/oom_gh_example.py" source_line=15}
  gather.6 = f32[71,1,82,3,1,75]{5,4,3,2,1,0} gather(bitcast.194, concatenate.11), offset_dims={1,2,3,4,5}, collapsed_slice_dims={}, start_index_map={0,1,2}, index_vector_dim=1, slice_sizes={1,82,3,1,75}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=15}
  bitcast.192 = f32[1,71,82,3,75]{4,3,2,1,0} bitcast(gather.6)
  ROOT slice.23 = f32[1,71,71,3,75]{4,3,2,1,0} slice(bitcast.192), slice={[0:1], [0:71], [0:71], [0:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=16}
} // fused_computation.6

ENTRY main.160 {
  constant_112 = f32[] constant(0)
  constant_109 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  fusion.6 = f32[1,71,71,3,75]{4,3,2,1,0} fusion(), kind=kLoop, calls=fused_computation.6, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=16}, backend_config={"kind":"","reification_cost":{"end_to_end_cycles":37920.934200644493}}
  fusion.1 = (f32[378075,3]{1,0}, f32[71,75,71,3]{3,2,1,0}) fusion(fusion.6), kind=kInput, calls=fused_computation.1, backend_config={"kind":"","reification_cost":{"end_to_end_cycles":48033.236820816994}}
  get-tuple-element.13 = f32[71,75,71,3]{3,2,1,0} get-tuple-element(fusion.1), index=1
  get-tuple-element.12 = f32[378075,3]{1,0} get-tuple-element(fusion.1), index=0
  custom-call = f32[378075,3]{1,0} custom-call(get-tuple-element.12, constant_109), custom_call_target="__cublas$gemm", metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/code/oom_gh_example.py" source_line=60}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
  fusion = f32[3540,64,71]{2,1,0} fusion(custom-call, get-tuple-element.13), kind=kInput, calls=fused_computation, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}, backend_config={"kind":"","reification_cost":{"end_to_end_cycles":107929808.01011539}}
  ROOT reduce.159 = f32[3540,71]{1,0} reduce(fusion, constant_112), dimensions={1}, to_apply=region_2.133, frontend_attributes={fingerprint_before_lhs="418aed85a9188471e15ed64b1856add3"}, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
} // main.160

jax==0.4.16, xla_client._version=194, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt

HloModule jit_run, is_scheduled=true, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

region_2.133 {
  Arg_1.135 = f32[] parameter(1)
  Arg_0.134 = f32[] parameter(0)
  ROOT add.15 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
}

fused_computation {
  constant_161 = f32[] constant(1)
  broadcast.219 = f32[71,1]{1,0} broadcast(constant_161), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=74}
  iota.23 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
  gather.14 = f32[71,1,1]{2,1,0} gather(broadcast.219, iota.23), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  constant_90 = f32[] constant(0.00052185118)
  broadcast.144 = f32[71,1,1]{2,1,0} broadcast(constant_90), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  multiply.32 = f32[71,1,1]{2,1,0} multiply(gather.14, broadcast.144), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  bitcast.164 = f32[71]{0} bitcast(multiply.32)
  broadcast.143 = f32[3540,71,1,71,75]{4,3,2,1,0} broadcast(bitcast.164), dimensions={3}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  param_1.26 = f32[3540,71,1,71,75]{4,3,2,1,0} parameter(1)
  constant_89 = f32[] constant(-0.5)
  broadcast.142 = f32[3540,71,1,71,75]{4,3,2,1,0} broadcast(constant_89), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.31 = f32[3540,71,1,71,75]{4,3,2,1,0} multiply(param_1.26, broadcast.142), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  exponential.3 = f32[3540,71,1,71,75]{4,3,2,1,0} exponential(multiply.31), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  param_0.15 = f32[3540,71,1,71,75]{4,3,2,1,0} parameter(0)
  multiply.30 = f32[3540,71,1,71,75]{4,3,2,1,0} multiply(param_0.15, broadcast.142), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  exponential.2 = f32[3540,71,1,71,75]{4,3,2,1,0} exponential(multiply.30), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  add.21 = f32[3540,71,1,71,75]{4,3,2,1,0} add(exponential.3, exponential.2), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.29 = f32[3540,71,1,71,75]{4,3,2,1,0} multiply(broadcast.143, add.21), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  bitcast.163 = f32[3540,71,5325]{2,1,0} bitcast(multiply.29)
  constant_96 = f32[] constant(0)
  ROOT reduce.9 = f32[3540,71]{1,0} reduce(bitcast.163, constant_96), dimensions={2}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
} // fused_computation

fused_computation.7 {
  constant_158 = f32[] constant(1)
  broadcast.173.clone.1 = f32[1,71,75,11,3]{4,3,2,1,0} broadcast(constant_158), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  broadcast.172.clone.1 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(constant_158), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  concatenate.12.clone.1 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(broadcast.173.clone.1, broadcast.172.clone.1, broadcast.173.clone.1, broadcast.172.clone.1), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.178.clone.1 = f32[71,164,3,1,75]{2,1,4,0,3} bitcast(concatenate.12.clone.1)
  iota.9.clone.1 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_157 = s32[] constant(82)
  broadcast.171.clone.1 = s32[71]{0} broadcast(constant_157), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  iota.8.clone.1 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
  subtract.13.clone.1 = s32[71]{0} subtract(broadcast.171.clone.1, iota.8.clone.1), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_152 = s32[] constant(0)
  broadcast.170.clone.1 = s32[71]{0} broadcast(constant_152), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.20.clone.1 = pred[71]{0} compare(subtract.13.clone.1, broadcast.170.clone.1), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_155 = s32[] constant(246)
  broadcast.169.clone.1 = s32[71]{0} broadcast(constant_155), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  subtract.12.clone.1 = s32[71]{0} subtract(broadcast.169.clone.1, iota.8.clone.1), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  select.11.clone.1 = s32[71]{0} select(compare.20.clone.1, subtract.12.clone.1, subtract.13.clone.1), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.177.clone.1 = s32[71,1]{1,0} bitcast(select.11.clone.1)
  broadcast.168.clone.1 = s32[71,1]{1,0} broadcast(constant_152), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  concatenate.11.clone.1 = s32[71,3]{1,0} concatenate(iota.9.clone.1, bitcast.177.clone.1, broadcast.168.clone.1), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  gather.6.clone.1 = f32[71,1,82,3,1,75]{5,4,3,2,1,0} gather(bitcast.178.clone.1, concatenate.11.clone.1), offset_dims={1,2,3,4,5}, collapsed_slice_dims={}, start_index_map={0,1,2}, index_vector_dim=1, slice_sizes={1,82,3,1,75}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.174.clone.1 = f32[1,71,82,3,75]{4,3,2,1,0} bitcast(gather.6.clone.1)
  slice.23.clone.1 = f32[1,71,71,3,75]{4,3,2,1,0} slice(bitcast.174.clone.1), slice={[0:1], [0:71], [0:71], [0:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 0) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=16}
  slice.27 = f32[1,71,71,1,75]{4,3,2,1,0} slice(bitcast.174.clone.1), slice={[0:1], [0:71], [0:71], [2:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 2) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_154 = f32[] constant(-0.03)
  broadcast.211 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_154), dimensions={}
  add.33 = f32[1,71,71,1,75]{4,3,2,1,0} add(slice.27, broadcast.211), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  abs.6 = f32[1,71,71,1,75]{4,3,2,1,0} abs(add.33), metadata={op_name="jit(run)/jit(main)/abs" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_153 = f32[] constant(0.03)
  broadcast.210 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_153), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  add.32 = f32[1,71,71,1,75]{4,3,2,1,0} add(abs.6, broadcast.210), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_151 = s32[] constant(2)
  dynamic-update-slice.8 = f32[1,71,71,3,75]{4,3,2,1,0} dynamic-update-slice(slice.23.clone.1, add.32, constant_152, constant_152, constant_152, /*index=5*/constant_151, constant_152), metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  transpose.20 = f32[1,71,75,71,3]{4,3,2,1,0} transpose(dynamic-update-slice.8), dimensions={0,1,4,2,3}, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  ROOT tuple.5 = (f32[1,71,75,71,3]{4,3,2,1,0}, f32[1,71,71,3,75]{4,3,2,1,0}) tuple(transpose.20, slice.23.clone.1)
} // fused_computation.7

fused_computation.1.clone.clone {
  param_0.37 = f32[378075,3]{1,0} parameter(0)
  constant_101.1 = f32[] constant(-1)
  broadcast.240 = f32[378075,3]{1,0} broadcast(constant_101.1), dimensions={}
  add.42 = f32[378075,3]{1,0} add(param_0.37, broadcast.240), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  multiply.42 = f32[378075,3]{1,0} multiply(add.42, add.42), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  constant_102.1 = f32[] constant(0.189035922)
  broadcast.239 = f32[378075,3]{1,0} broadcast(constant_102.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.41 = f32[378075,3]{1,0} multiply(multiply.42, broadcast.239), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  bitcast.206 = f32[71,75,71,3]{3,2,1,0} bitcast(multiply.41)
  broadcast.238 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(bitcast.206), dimensions={2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  bitcast.205 = f32[3540,71,75,71,3]{4,3,2,1,0} bitcast(broadcast.238)
  constant_103.1 = f32[] constant(0)
  reduce.14 = f32[3540,71,75,71]{3,2,1,0} reduce(bitcast.205, constant_103.1), dimensions={4}, to_apply=region_2.133
  bitcast.204 = f32[1,3540,71,75,71]{4,3,2,1,0} bitcast(reduce.14)
  ROOT transpose.26 = f32[3540,71,1,71,75]{4,3,2,1,0} transpose(bitcast.204), dimensions={1,4,0,2,3}, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
} // fused_computation.1.clone.clone

fused_computation.2.clone.clone {
  param_0.38 = f32[1,71,71,3,75]{4,3,2,1,0} parameter(0)
  constant_141.1 = f32[] constant(1)
  broadcast.251 = f32[1,71,75,11,3]{4,3,2,1,0} broadcast(constant_141.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  broadcast.250 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(constant_141.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  concatenate.28 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(broadcast.251, broadcast.250, broadcast.251, broadcast.250), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.212 = f32[71,164,3,1,75]{2,1,4,0,3} bitcast(concatenate.28)
  iota.27 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_140.1 = s32[] constant(82)
  broadcast.249 = s32[71]{0} broadcast(constant_140.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  iota.26 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
  subtract.29 = s32[71]{0} subtract(broadcast.249, iota.26), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_135.1 = s32[] constant(0)
  broadcast.248 = s32[71]{0} broadcast(constant_135.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.28 = pred[71]{0} compare(subtract.29, broadcast.248), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_138.1 = s32[] constant(246)
  broadcast.247 = s32[71]{0} broadcast(constant_138.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  subtract.28 = s32[71]{0} subtract(broadcast.247, iota.26), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  select.19 = s32[71]{0} select(compare.28, subtract.28, subtract.29), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.211 = s32[71,1]{1,0} bitcast(select.19)
  broadcast.246 = s32[71,1]{1,0} broadcast(constant_135.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  concatenate.27 = s32[71,3]{1,0} concatenate(iota.27, bitcast.211, broadcast.246), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  gather.16 = f32[71,1,82,3,1,75]{5,4,3,2,1,0} gather(bitcast.212, concatenate.27), offset_dims={1,2,3,4,5}, collapsed_slice_dims={}, start_index_map={0,1,2}, index_vector_dim=1, slice_sizes={1,82,3,1,75}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.210 = f32[1,71,82,3,75]{4,3,2,1,0} bitcast(gather.16)
  slice.31 = f32[1,71,71,1,75]{4,3,2,1,0} slice(bitcast.210), slice={[0:1], [0:71], [0:71], [2:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 2) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_137.1 = f32[] constant(-0.03)
  broadcast.245 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_137.1), dimensions={}
  add.45 = f32[1,71,71,1,75]{4,3,2,1,0} add(slice.31, broadcast.245), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  abs.10 = f32[1,71,71,1,75]{4,3,2,1,0} abs(add.45), metadata={op_name="jit(run)/jit(main)/abs" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_136.1 = f32[] constant(0.03)
  broadcast.244 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_136.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  add.44 = f32[1,71,71,1,75]{4,3,2,1,0} add(abs.10, broadcast.244), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_134.1 = s32[] constant(2)
  dynamic-update-slice.12 = f32[1,71,71,3,75]{4,3,2,1,0} dynamic-update-slice(param_0.38, add.44, constant_135.1, constant_135.1, constant_135.1, /*index=5*/constant_134.1, constant_135.1), metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_98.1 = f32[] constant(-1)
  broadcast.243 = f32[1,71,71,3,75]{4,3,2,1,0} broadcast(constant_98.1), dimensions={}
  add.43 = f32[1,71,71,3,75]{4,3,2,1,0} add(dynamic-update-slice.12, broadcast.243), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  multiply.44 = f32[1,71,71,3,75]{4,3,2,1,0} multiply(add.43, add.43), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  constant_99.1 = f32[] constant(0.189035922)
  broadcast.242 = f32[1,71,71,3,75]{4,3,2,1,0} broadcast(constant_99.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.43 = f32[1,71,71,3,75]{4,3,2,1,0} multiply(multiply.44, broadcast.242), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  bitcast.209 = f32[71,71,3,75]{3,2,1,0} bitcast(multiply.43)
  transpose.28 = f32[71,75,71,3]{3,2,1,0} transpose(bitcast.209), dimensions={0,3,1,2}
  broadcast.241 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(transpose.28), dimensions={2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  bitcast.208 = f32[3540,71,75,71,3]{4,3,2,1,0} bitcast(broadcast.241)
  constant_100.1 = f32[] constant(0)
  reduce.15 = f32[3540,71,75,71]{3,2,1,0} reduce(bitcast.208, constant_100.1), dimensions={4}, to_apply=region_2.133
  bitcast.207 = f32[1,3540,71,75,71]{4,3,2,1,0} bitcast(reduce.15)
  ROOT transpose.27 = f32[3540,71,1,71,75]{4,3,2,1,0} transpose(bitcast.207), dimensions={1,4,0,2,3}, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
} // fused_computation.2.clone.clone

fused_computation.7.clone.clone.clone {
  constant_158.2 = f32[] constant(1)
  broadcast.173.clone.4 = f32[1,71,75,11,3]{4,3,2,1,0} broadcast(constant_158.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  broadcast.172.clone.4 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(constant_158.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  concatenate.12.clone.4 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(broadcast.173.clone.4, broadcast.172.clone.4, broadcast.173.clone.4, broadcast.172.clone.4), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.178.clone.4 = f32[71,164,3,1,75]{2,1,4,0,3} bitcast(concatenate.12.clone.4)
  iota.9.clone.4 = s32[71,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_157.2 = s32[] constant(82)
  broadcast.171.clone.4 = s32[71]{0} broadcast(constant_157.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  iota.8.clone.4 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
  subtract.13.clone.4 = s32[71]{0} subtract(broadcast.171.clone.4, iota.8.clone.4), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_152.2 = s32[] constant(0)
  broadcast.170.clone.4 = s32[71]{0} broadcast(constant_152.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.20.clone.4 = pred[71]{0} compare(subtract.13.clone.4, broadcast.170.clone.4), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant_155.2 = s32[] constant(246)
  broadcast.169.clone.4 = s32[71]{0} broadcast(constant_155.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  subtract.12.clone.4 = s32[71]{0} subtract(broadcast.169.clone.4, iota.8.clone.4), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  select.11.clone.4 = s32[71]{0} select(compare.20.clone.4, subtract.12.clone.4, subtract.13.clone.4), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.177.clone.4 = s32[71,1]{1,0} bitcast(select.11.clone.4)
  broadcast.168.clone.4 = s32[71,1]{1,0} broadcast(constant_152.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  concatenate.11.clone.4 = s32[71,3]{1,0} concatenate(iota.9.clone.4, bitcast.177.clone.4, broadcast.168.clone.4), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  gather.6.clone.4 = f32[71,1,82,3,1,75]{5,4,3,2,1,0} gather(bitcast.178.clone.4, concatenate.11.clone.4), offset_dims={1,2,3,4,5}, collapsed_slice_dims={}, start_index_map={0,1,2}, index_vector_dim=1, slice_sizes={1,82,3,1,75}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  bitcast.174.clone.4 = f32[1,71,82,3,75]{4,3,2,1,0} bitcast(gather.6.clone.4)
  slice.23.clone.4 = f32[1,71,71,3,75]{4,3,2,1,0} slice(bitcast.174.clone.4), slice={[0:1], [0:71], [0:71], [0:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 0) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=16}
  slice.32 = f32[1,71,71,1,75]{4,3,2,1,0} slice(bitcast.174.clone.4), slice={[0:1], [0:71], [0:71], [2:3], [0:75]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 2) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_154.2 = f32[] constant(-0.03)
  broadcast.253 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_154.2), dimensions={}
  add.47 = f32[1,71,71,1,75]{4,3,2,1,0} add(slice.32, broadcast.253), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  abs.11 = f32[1,71,71,1,75]{4,3,2,1,0} abs(add.47), metadata={op_name="jit(run)/jit(main)/abs" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_153.2 = f32[] constant(0.03)
  broadcast.252 = f32[1,71,71,1,75]{4,3,2,1,0} broadcast(constant_153.2), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  add.46 = f32[1,71,71,1,75]{4,3,2,1,0} add(abs.11, broadcast.252), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant_151.2 = s32[] constant(2)
  dynamic-update-slice.13 = f32[1,71,71,3,75]{4,3,2,1,0} dynamic-update-slice(slice.23.clone.4, add.46, constant_152.2, constant_152.2, constant_152.2, /*index=5*/constant_151.2, constant_152.2), metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  transpose.29 = f32[1,71,75,71,3]{4,3,2,1,0} transpose(dynamic-update-slice.13), dimensions={0,1,4,2,3}, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  ROOT tuple.8 = (f32[1,71,75,71,3]{4,3,2,1,0}, f32[1,71,71,3,75]{4,3,2,1,0}) tuple(transpose.29, slice.23.clone.4)
} // fused_computation.7.clone.clone.clone

ENTRY main.160 {
  constant_19 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  fusion.7 = (f32[1,71,75,71,3]{4,3,2,1,0}, f32[1,71,71,3,75]{4,3,2,1,0}) fusion(), kind=kInput, calls=fused_computation.7, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  get-tuple-element.12 = f32[1,71,75,71,3]{4,3,2,1,0} get-tuple-element(fusion.7), index=0
  bitcast.108 = f32[378075,3]{1,0} bitcast(get-tuple-element.12)
  custom-call = f32[378075,3]{1,0} custom-call(bitcast.108, constant_19), custom_call_target="__cublas$gemm", metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=60}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
  fusion.1.remat2 = f32[3540,71,1,71,75]{4,3,2,1,0} fusion(custom-call), kind=kInput, calls=fused_computation.1.clone.clone, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  fusion.7.remat3 = (f32[1,71,75,71,3]{4,3,2,1,0}, f32[1,71,71,3,75]{4,3,2,1,0}) fusion(), kind=kInput, calls=fused_computation.7.clone.clone.clone, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  get-tuple-element.16 = f32[1,71,71,3,75]{4,3,2,1,0} get-tuple-element(fusion.7.remat3), index=1
  fusion.2.remat2 = f32[3540,71,1,71,75]{4,3,2,1,0} fusion(get-tuple-element.16), kind=kInput, calls=fused_computation.2.clone.clone, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  ROOT fusion = f32[3540,71]{1,0} fusion(fusion.1.remat2, fusion.2.remat2), kind=kInput, calls=fused_computation, frontend_attributes={fingerprint_before_lhs="6bc146ca46cc950c369df2c14ca58b4b"}, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
} // main.160
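To sanity-check the ~5GB figure, the sketch below computes the dense size of the intermediate buffers named in the optimized HLO above. It assumes dense f32 buffers (4 bytes per element) with no padding or alignment overhead. The shapes are copied from `fusion.1.remat2`, `fusion.2.remat2` and the ROOT `fusion`. `buffer_bytes` is a hypothetical helper written for this illustration, not an XLA or JAX API.

```python
import math

F32_BYTES = 4  # f32 = 4 bytes per element (assumes dense, unpadded layout)


def buffer_bytes(shape, elem_bytes=F32_BYTES):
    """Size in bytes of a dense buffer with the given shape."""
    return math.prod(shape) * elem_bytes


# Output shape of fusion.1.remat2 and fusion.2.remat2 (both are operands of the ROOT fusion).
remat_out = (3540, 71, 1, 71, 75)
# Shape of the module's final result (ROOT fusion).
root_out = (3540, 71)

per_fusion = buffer_bytes(remat_out)
print(f"each remat fusion output: {per_fusion:,} bytes (~{per_fusion / 2**30:.2f} GiB)")
print(f"both operands together:   {2 * per_fusion:,} bytes")
print(f"module result:            {buffer_bytes(root_out):,} bytes")
```

Each rematerialized fusion output comes to 5,353,542,000 bytes (about 4.99 GiB), which is the same order as the ~5GB allocation in the original report. The module's actual result is only about 1 MB. Because both intermediates feed the same ROOT fusion, they would have to be live at the same time if materialized. This suggests, though it does not prove, that the large request comes from these intermediate buffers rather than from the output.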
mooskagh commented 1 month ago

Could you also share the before_optimizations module?

Also, what GPU are you running it on?

pwithams commented 1 month ago

Locally I have an NVIDIA GeForce RTX 4050. I'm also seeing the issue on AWS g4dn instances with one or more NVIDIA T4 GPUs.

These outputs were generated with an NVIDIA driver for CUDA 12.3, but upgrading to the latest JAX version (0.4.31) with CUDA 12.6 appears to show the same behavior.

before_optimizations.txt

jax==0.4.14, xla_client._version=174, module_0000.jit_run.before_optimizations.txt

HloModule jit_run, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

atleast_2d.30 {
  ROOT Arg_0.31 = f32[71,75,11,3]{3,2,1,0} parameter(0)
}

atleast_2d_0.33 {
  ROOT Arg_0.34 = f32[1,71,3]{2,1,0} parameter(0)
}

_where.39 {
  Arg_0.40 = pred[] parameter(0)
  Arg_1.41 = s32[] parameter(1)
  Arg_2.42 = s32[] parameter(2)
  ROOT select.43 = s32[] select(Arg_0.40, Arg_1.41, Arg_2.42), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/jit(_where)/select_n" source_file="/code/oom_gh_example.py" source_line=15}
}

remainder.44 {
  Arg_0.45 = s32[71]{0} parameter(0)
  Arg_1.46 = s32[] parameter(1)
  constant.50 = s32[] constant(0)
  compare.51 = pred[] compare(Arg_1.46, constant.50), direction=EQ, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/eq" source_file="/code/oom_gh_example.py" source_line=15}
  constant.49 = s32[] constant(1)
  call.52 = s32[] call(compare.51, constant.49, Arg_1.46), to_apply=_where.39
  broadcast.53 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/code/oom_gh_example.py" source_line=15}
  remainder.54 = s32[71]{0} remainder(Arg_0.45, broadcast.53), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/code/oom_gh_example.py" source_line=15}
  constant.47 = s32[] constant(0)
  broadcast.48 = s32[71]{0} broadcast(constant.47), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  compare.56 = pred[71]{0} compare(remainder.54, broadcast.48), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/code/oom_gh_example.py" source_line=15}
  compare.57 = pred[] compare(call.52, constant.50), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/code/oom_gh_example.py" source_line=15}
  broadcast.58 = pred[71]{0} broadcast(compare.57), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  compare.59 = pred[71]{0} compare(compare.56, broadcast.58), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  compare.55 = pred[71]{0} compare(remainder.54, broadcast.48), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
  and.60 = pred[71]{0} and(compare.59, compare.55), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/and" source_file="/code/oom_gh_example.py" source_line=15}
  broadcast.61 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/code/oom_gh_example.py" source_line=15}
  add.62 = s32[71]{0} add(remainder.54, broadcast.61), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/code/oom_gh_example.py" source_line=15}
  ROOT select.63 = s32[71]{0} select(and.60, add.62, remainder.54), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/select_n" source_file="/code/oom_gh_example.py" source_line=15}
} // remainder.44

_roll_dynamic.64 {
  Arg_0.65 = f32[1,71,75,82,3]{4,3,2,1,0} parameter(0)
  concatenate.77 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(Arg_0.65, Arg_0.65), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=15}
  iota.83 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
  reshape.84 = s32[71,1]{1,0} reshape(iota.83), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
  constant.73 = s32[] constant(82)
  broadcast.74 = s32[71]{0} broadcast(constant.73), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  Arg_1.66 = s32[71]{0} parameter(1)
  constant.75 = s32[] constant(82)
  call.76 = s32[71]{0} call(Arg_1.66, constant.75), to_apply=remainder.44
  subtract.78 = s32[71]{0} subtract(broadcast.74, call.76), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
  constant.71 = s32[] constant(0)
  broadcast.72 = s32[71]{0} broadcast(constant.71), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
  compare.79 = pred[71]{0} compare(subtract.78, broadcast.72), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
  constant.69 = s32[] constant(164)
  broadcast.70 = s32[71]{0} broadcast(constant.69), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  add.80 = s32[71]{0} add(subtract.78, broadcast.70), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
  select.81 = s32[71]{0} select(compare.79, add.80, subtract.78), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/code/oom_gh_example.py" source_line=15}
  reshape.82 = s32[71,1]{1,0} reshape(select.81), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/code/oom_gh_example.py" source_line=15}
  constant.67 = s32[] constant(0)
  broadcast.68 = s32[71,1]{1,0} broadcast(constant.67), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/code/oom_gh_example.py" source_line=15}
  concatenate.85 = s32[71,3]{1,0} concatenate(reshape.84, reshape.82, broadcast.68), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/code/oom_gh_example.py" source_line=15}
  ROOT gather.86 = f32[1,71,75,82,3]{4,3,2,1,0} gather(concatenate.77, concatenate.85), offset_dims={0,2,3,4}, collapsed_slice_dims={1}, start_index_map={1,3,4}, index_vector_dim=1, slice_sizes={1,1,75,82,3}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=15}
} // _roll_dynamic.64

region_0.94 {
  Arg_0.95 = f32[] parameter(0)
  ROOT Arg_1.96 = f32[] parameter(1)
}

region_1.116 {
  Arg_0.117 = f32[] parameter(0)
  Arg_1.118 = f32[] parameter(1)
  ROOT multiply.119 = f32[] multiply(Arg_0.117, Arg_1.118), metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/code/oom_gh_example.py" source_line=47}
}

region_2.133 {
  Arg_0.134 = f32[] parameter(0)
  Arg_1.135 = f32[] parameter(1)
  ROOT add.136 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
}

region_3.142 {
  Arg_0.143 = f32[] parameter(0)
  Arg_1.144 = f32[] parameter(1)
  ROOT add.145 = f32[] add(Arg_0.143, Arg_1.144), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=50}
}

region_4.155 {
  Arg_0.156 = f32[] parameter(0)
  Arg_1.157 = f32[] parameter(1)
  ROOT add.158 = f32[] add(Arg_0.156, Arg_1.157), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
}

ENTRY main.160 {
  constant.25 = f32[] constant(1)
  broadcast.26 = f32[1,71]{1,0} broadcast(constant.25), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/code/oom_gh_example.py" source_line=74}
  iota.29 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/code/oom_gh_example.py" source_line=77}
  constant.15 = s32[] constant(0)
  broadcast.16 = s32[71]{0} broadcast(constant.15), dimensions={}, metadata={op_name="jit(run)/jit(main)/lt" source_file="/code/oom_gh_example.py" source_line=63}
  compare.109 = pred[71]{0} compare(iota.29, broadcast.16), direction=LT, metadata={op_name="jit(run)/jit(main)/lt" source_file="/code/oom_gh_example.py" source_line=63}
  constant.13 = s32[] constant(71)
  broadcast.14 = s32[71]{0} broadcast(constant.13), dimensions={}, metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=63}
  add.110 = s32[71]{0} add(iota.29, broadcast.14), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=63}
  select.111 = s32[71]{0} select(compare.109, add.110, iota.29), metadata={op_name="jit(run)/jit(main)/select_n" source_file="/code/oom_gh_example.py" source_line=63}
  reshape.112 = s32[71,1]{1,0} reshape(select.111), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/code/oom_gh_example.py" source_line=63}
  gather.113 = f32[1,71,1]{2,1,0} gather(broadcast.26, reshape.112), offset_dims={0,2}, collapsed_slice_dims={}, start_index_map={1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=63}
  reshape.114 = f32[1,71]{1,0} reshape(gather.113), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(2,)]" source_file="/code/oom_gh_example.py" source_line=63}
  constant.7 = f32[] constant(0.1)
  broadcast.8 = f32[1,71]{1,0} broadcast(constant.7), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  multiply.115 = f32[1,71]{1,0} multiply(reshape.114, broadcast.8), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  reshape.123 = f32[1,71,1]{2,1,0} reshape(multiply.115), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 1) broadcast_dimensions=(0, 1)]" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.124 = f32[1,71,1]{2,1,0} broadcast(reshape.123), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  reshape.125 = f32[1,71]{1,0} reshape(broadcast.124), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.126 = f32[1,71,71]{2,1,0} broadcast(reshape.125), dimensions={0,1}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  constant.9 = f32[] constant(2.3)
  broadcast.10 = f32[1,71,3]{2,1,0} broadcast(constant.9), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=64}
  constant.28 = f32[] constant(1)
  reduce.120 = f32[1,71]{1,0} reduce(broadcast.10, constant.28), dimensions={2}, to_apply=region_1.116, metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/code/oom_gh_example.py" source_line=47}
  constant.5 = f32[] constant(15.7496099)
  broadcast.6 = f32[1,71]{1,0} broadcast(constant.5), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=47}
  multiply.121 = f32[1,71]{1,0} multiply(reduce.120, broadcast.6), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=47}
  reshape.122 = f32[1,1,71]{2,1,0} reshape(multiply.121), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71) broadcast_dimensions=(0, 2)]" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.127 = f32[1,1,71]{2,1,0} broadcast(reshape.122), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  reshape.128 = f32[1,71]{1,0} reshape(broadcast.127), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.129 = f32[1,71,71]{2,1,0} broadcast(reshape.128), dimensions={0,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  divide.130 = f32[1,71,71]{2,1,0} divide(broadcast.126, broadcast.129), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
  reshape.150 = f32[1,1,71,1,71]{4,3,2,1,0} reshape(divide.130), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 1, 71) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.151 = f32[1,1,71,1,71]{4,3,2,1,0} broadcast(reshape.150), dimensions={0,1,2,3,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  reshape.152 = f32[1,71,71]{2,1,0} reshape(broadcast.151), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  broadcast.153 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(reshape.152), dimensions={0,2,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  constant.23 = f32[] constant(1)
  broadcast.24 = f32[71,75,11,3]{3,2,1,0} broadcast(constant.23), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=8}
  call.32 = f32[71,75,11,3]{3,2,1,0} call(broadcast.24), to_apply=atleast_2d.30
  reshape.37 = f32[1,71,75,11,3]{4,3,2,1,0} reshape(call.32), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  constant.11 = f32[] constant(1)
  broadcast.12 = f32[1,71,3]{2,1,0} broadcast(constant.11), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 3) broadcast_dimensions=(1, 2)]" source_file="/code/oom_gh_example.py" source_line=12}
  call.35 = f32[1,71,3]{2,1,0} call(broadcast.12), to_apply=atleast_2d_0.33
  broadcast.36 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(call.35), dimensions={0,3,4}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
  concatenate.38 = f32[1,71,75,82,3]{4,3,2,1,0} concatenate(reshape.37, broadcast.36), dimensions={3}, metadata={op_name="jit(run)/jit(main)/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=9}
  call.87 = f32[1,71,75,82,3]{4,3,2,1,0} call(concatenate.38, iota.29), to_apply=_roll_dynamic.64
  slice.88 = f32[1,71,75,71,3]{4,3,2,1,0} slice(call.87), slice={[0:1], [0:71], [0:75], [0:71], [0:3]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=16}
  constant.20 = s32[1]{0} constant({2})
  slice.89 = f32[1,71,75,71,1]{4,3,2,1,0} slice(slice.88), slice={[0:1], [0:71], [0:75], [0:71], [2:3]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=17}
  reshape.90 = f32[1,71,75,71]{3,2,1,0} reshape(slice.89), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(4,)]" source_file="/code/oom_gh_example.py" source_line=17}
  constant.21 = f32[] constant(0.03)
  broadcast.22 = f32[1,71,75,71]{3,2,1,0} broadcast(constant.21), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
  subtract.91 = f32[1,71,75,71]{3,2,1,0} subtract(reshape.90, broadcast.22), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
  abs.92 = f32[1,71,75,71]{3,2,1,0} abs(subtract.91), metadata={op_name="jit(run)/jit(main)/abs" source_file="/code/oom_gh_example.py" source_line=17}
  add.93 = f32[1,71,75,71]{3,2,1,0} add(abs.92, broadcast.22), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=17}
  scatter.97 = f32[1,71,75,71,3]{4,3,2,1,0} scatter(slice.88, constant.20, add.93), update_window_dims={0,1,2,3}, inserted_window_dims={4}, scatter_dims_to_operand_dims={4}, index_vector_dim=0, indices_are_sorted=true, unique_indices=true, to_apply=region_0.94, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/code/oom_gh_example.py" source_line=17}
  reshape.98 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(scatter.97), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/code/oom_gh_example.py" source_line=58}
  broadcast.99 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  reshape.100 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.99), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  broadcast.101 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.100), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  constant.17 = f32[] constant(1)
  broadcast.18 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.17), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  subtract.102 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.101, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
  multiply.131 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.102, subtract.102), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  constant.3 = f32[] constant(5.29)
  broadcast.4 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.3), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
  divide.132 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.131, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
  constant.27 = f32[] constant(0)
  reduce.137 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.132, constant.27), dimensions={5}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
  constant.1 = f32[] constant(-0.5)
  broadcast.2 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(constant.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.138 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.137, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
  exponential.139 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.138), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=49}
  constant.19 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  dot.103 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.97, constant.19), lhs_contracting_dims={4}, rhs_contracting_dims={0}, metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/code/oom_gh_example.py" source_line=60}
  reshape.104 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.103), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/code/oom_gh_example.py" source_line=61}
  broadcast.105 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.104), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  reshape.106 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.105), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  broadcast.107 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.106), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  subtract.108 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.107, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
  multiply.140 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.108, subtract.108), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
  divide.141 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.140, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=50}
  reduce.146 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.141, constant.27), dimensions={5}, to_apply=region_3.142, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=50}
  multiply.147 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.146, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
  exponential.148 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.147), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=50}
  add.149 = f32[1,3540,71,75,71]{4,3,2,1,0} add(exponential.139, exponential.148), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=49}
  multiply.154 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(broadcast.153, add.149), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
  ROOT reduce.159 = f32[3540,71]{1,0} reduce(multiply.154, constant.27), dimensions={0,2,3}, to_apply=region_4.155, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
} // main.160
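For scale, the largest intermediates in the HLO above can be sized directly from their shapes. The rough sketch below assumes every element is f32 (4 bytes) and a dense layout with no padding. The instruction names are copied from the dump; the helper `nbytes` is just for illustration.

```python
from math import prod

F32_BYTES = 4  # every tensor listed below is f32

# Shapes copied from the HLO dump above.
shapes = {
    "subtract.102 / multiply.131 / divide.132": (1, 3540, 71, 75, 71, 3),
    "reduce.137 / exponential.139 / multiply.154": (1, 3540, 71, 75, 71),
}

def nbytes(shape, itemsize=F32_BYTES):
    """Size in bytes of a dense array with this shape (ignores layout padding)."""
    return prod(shape) * itemsize

for name, shape in shapes.items():
    b = nbytes(shape)
    print(f"{name}: {b:,} bytes ({b / 2**30:.2f} GiB)")
```

This prints about 16.06 GB (14.96 GiB) for the 6-D buffers and about 5.35 GB (4.99 GiB) for the 5-D ones. That is in the same ballpark as the ~5GB allocation described above. However, whether any of these buffers is actually materialized depends on which ops XLA fuses, so treat this only as a rough guide.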

jax==0.4.16, xla_client._version=194, module_0000.jit_run.before_optimizations.txt

HloModule jit_run, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

atleast_2d.30 {
  ROOT Arg_0.31 = f32[71,75,11,3]{3,2,1,0} parameter(0)
}

atleast_2d_0.33 {
  ROOT Arg_0.34 = f32[1,71,3]{2,1,0} parameter(0)
}

_where.39 {
  Arg_0.40 = pred[] parameter(0)
  Arg_1.41 = s32[] parameter(1)
  Arg_2.42 = s32[] parameter(2)
  ROOT select.43 = s32[] select(Arg_0.40, Arg_1.41, Arg_2.42), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/jit(_where)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
}

remainder.44 {
  Arg_0.45 = s32[71]{0} parameter(0)
  Arg_1.46 = s32[] parameter(1)
  constant.50 = s32[] constant(0)
  compare.51 = pred[] compare(Arg_1.46, constant.50), direction=EQ, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/eq" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.49 = s32[] constant(1)
  call.52 = s32[] call(compare.51, constant.49, Arg_1.46), to_apply=_where.39
  broadcast.53 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  remainder.54 = s32[71]{0} remainder(Arg_0.45, broadcast.53), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.47 = s32[] constant(0)
  broadcast.48 = s32[71]{0} broadcast(constant.47), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.56 = pred[71]{0} compare(remainder.54, broadcast.48), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.57 = pred[] compare(call.52, constant.50), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  broadcast.58 = pred[71]{0} broadcast(compare.57), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.59 = pred[71]{0} compare(compare.56, broadcast.58), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.55 = pred[71]{0} compare(remainder.54, broadcast.48), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  and.60 = pred[71]{0} and(compare.59, compare.55), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/and" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  broadcast.61 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  add.62 = s32[71]{0} add(remainder.54, broadcast.61), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  ROOT select.63 = s32[71]{0} select(and.60, add.62, remainder.54), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
} // remainder.44

_roll_dynamic.64 {
  Arg_0.65 = f32[1,71,75,82,3]{4,3,2,1,0} parameter(0)
  concatenate.77 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(Arg_0.65, Arg_0.65), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  iota.83 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  reshape.84 = s32[71,1]{1,0} reshape(iota.83), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.73 = s32[] constant(82)
  broadcast.74 = s32[71]{0} broadcast(constant.73), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  Arg_1.66 = s32[71]{0} parameter(1)
  constant.75 = s32[] constant(82)
  call.76 = s32[71]{0} call(Arg_1.66, constant.75), to_apply=remainder.44
  subtract.78 = s32[71]{0} subtract(broadcast.74, call.76), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.71 = s32[] constant(0)
  broadcast.72 = s32[71]{0} broadcast(constant.71), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  compare.79 = pred[71]{0} compare(subtract.78, broadcast.72), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.69 = s32[] constant(164)
  broadcast.70 = s32[71]{0} broadcast(constant.69), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  add.80 = s32[71]{0} add(subtract.78, broadcast.70), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  select.81 = s32[71]{0} select(compare.79, add.80, subtract.78), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  reshape.82 = s32[71,1]{1,0} reshape(select.81), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  constant.67 = s32[] constant(0)
  broadcast.68 = s32[71,1]{1,0} broadcast(constant.67), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  concatenate.85 = s32[71,3]{1,0} concatenate(reshape.84, reshape.82, broadcast.68), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
  ROOT gather.86 = f32[1,71,75,82,3]{4,3,2,1,0} gather(concatenate.77, concatenate.85), offset_dims={0,2,3,4}, collapsed_slice_dims={1}, start_index_map={1,3,4}, index_vector_dim=1, slice_sizes={1,1,75,82,3}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
} // _roll_dynamic.64

region_0.94 {
  Arg_0.95 = f32[] parameter(0)
  ROOT Arg_1.96 = f32[] parameter(1)
}

region_1.116 {
  Arg_0.117 = f32[] parameter(0)
  Arg_1.118 = f32[] parameter(1)
  ROOT multiply.119 = f32[] multiply(Arg_0.117, Arg_1.118), metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
}

region_2.133 {
  Arg_0.134 = f32[] parameter(0)
  Arg_1.135 = f32[] parameter(1)
  ROOT add.136 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
}

region_3.142 {
  Arg_0.143 = f32[] parameter(0)
  Arg_1.144 = f32[] parameter(1)
  ROOT add.145 = f32[] add(Arg_0.143, Arg_1.144), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
}

region_4.155 {
  Arg_0.156 = f32[] parameter(0)
  Arg_1.157 = f32[] parameter(1)
  ROOT add.158 = f32[] add(Arg_0.156, Arg_1.157), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
}

ENTRY main.160 {
  constant.25 = f32[] constant(1)
  broadcast.26 = f32[1,71]{1,0} broadcast(constant.25), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=74}
  iota.29 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
  constant.15 = s32[] constant(0)
  broadcast.16 = s32[71]{0} broadcast(constant.15), dimensions={}, metadata={op_name="jit(run)/jit(main)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  compare.109 = pred[71]{0} compare(iota.29, broadcast.16), direction=LT, metadata={op_name="jit(run)/jit(main)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  constant.13 = s32[] constant(71)
  broadcast.14 = s32[71]{0} broadcast(constant.13), dimensions={}, metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  add.110 = s32[71]{0} add(iota.29, broadcast.14), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  select.111 = s32[71]{0} select(compare.109, add.110, iota.29), metadata={op_name="jit(run)/jit(main)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  reshape.112 = s32[71,1]{1,0} reshape(select.111), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  gather.113 = f32[1,71,1]{2,1,0} gather(broadcast.26, reshape.112), offset_dims={0,2}, collapsed_slice_dims={}, start_index_map={1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  reshape.114 = f32[1,71]{1,0} reshape(gather.113), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
  constant.7 = f32[] constant(0.1)
  broadcast.8 = f32[1,71]{1,0} broadcast(constant.7), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  multiply.115 = f32[1,71]{1,0} multiply(reshape.114, broadcast.8), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  reshape.123 = f32[1,71,1]{2,1,0} reshape(multiply.115), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 1) broadcast_dimensions=(0, 1)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.124 = f32[1,71,1]{2,1,0} broadcast(reshape.123), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  reshape.125 = f32[1,71]{1,0} reshape(broadcast.124), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.126 = f32[1,71,71]{2,1,0} broadcast(reshape.125), dimensions={0,1}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  constant.9 = f32[] constant(2.3)
  broadcast.10 = f32[1,71,3]{2,1,0} broadcast(constant.9), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=64}
  constant.28 = f32[] constant(1)
  reduce.120 = f32[1,71]{1,0} reduce(broadcast.10, constant.28), dimensions={2}, to_apply=region_1.116, metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
  constant.5 = f32[] constant(15.7496099)
  broadcast.6 = f32[1,71]{1,0} broadcast(constant.5), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
  multiply.121 = f32[1,71]{1,0} multiply(reduce.120, broadcast.6), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
  reshape.122 = f32[1,1,71]{2,1,0} reshape(multiply.121), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71) broadcast_dimensions=(0, 2)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.127 = f32[1,1,71]{2,1,0} broadcast(reshape.122), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  reshape.128 = f32[1,71]{1,0} reshape(broadcast.127), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.129 = f32[1,71,71]{2,1,0} broadcast(reshape.128), dimensions={0,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  divide.130 = f32[1,71,71]{2,1,0} divide(broadcast.126, broadcast.129), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  reshape.150 = f32[1,1,71,1,71]{4,3,2,1,0} reshape(divide.130), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 1, 71) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.151 = f32[1,1,71,1,71]{4,3,2,1,0} broadcast(reshape.150), dimensions={0,1,2,3,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  reshape.152 = f32[1,71,71]{2,1,0} reshape(broadcast.151), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  broadcast.153 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(reshape.152), dimensions={0,2,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  constant.23 = f32[] constant(1)
  broadcast.24 = f32[71,75,11,3]{3,2,1,0} broadcast(constant.23), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=8}
  call.32 = f32[71,75,11,3]{3,2,1,0} call(broadcast.24), to_apply=atleast_2d.30
  reshape.37 = f32[1,71,75,11,3]{4,3,2,1,0} reshape(call.32), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  constant.11 = f32[] constant(1)
  broadcast.12 = f32[1,71,3]{2,1,0} broadcast(constant.11), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 3) broadcast_dimensions=(1, 2)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=12}
  call.35 = f32[1,71,3]{2,1,0} call(broadcast.12), to_apply=atleast_2d_0.33
  broadcast.36 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(call.35), dimensions={0,3,4}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  concatenate.38 = f32[1,71,75,82,3]{4,3,2,1,0} concatenate(reshape.37, broadcast.36), dimensions={3}, metadata={op_name="jit(run)/jit(main)/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
  call.87 = f32[1,71,75,82,3]{4,3,2,1,0} call(concatenate.38, iota.29), to_apply=_roll_dynamic.64
  slice.88 = f32[1,71,75,71,3]{4,3,2,1,0} slice(call.87), slice={[0:1], [0:71], [0:75], [0:71], [0:3]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 0) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=16}
  constant.20 = s32[1]{0} constant({2})
  slice.89 = f32[1,71,75,71,1]{4,3,2,1,0} slice(slice.88), slice={[0:1], [0:71], [0:75], [0:71], [2:3]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 2) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  reshape.90 = f32[1,71,75,71]{3,2,1,0} reshape(slice.89), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(4,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  constant.21 = f32[] constant(0.03)
  broadcast.22 = f32[1,71,75,71]{3,2,1,0} broadcast(constant.21), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  subtract.91 = f32[1,71,75,71]{3,2,1,0} subtract(reshape.90, broadcast.22), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  abs.92 = f32[1,71,75,71]{3,2,1,0} abs(subtract.91), metadata={op_name="jit(run)/jit(main)/abs" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  add.93 = f32[1,71,75,71]{3,2,1,0} add(abs.92, broadcast.22), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  scatter.97 = f32[1,71,75,71,3]{4,3,2,1,0} scatter(slice.88, constant.20, add.93), update_window_dims={0,1,2,3}, inserted_window_dims={4}, scatter_dims_to_operand_dims={4}, index_vector_dim=0, indices_are_sorted=true, unique_indices=true, to_apply=region_0.94, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
  reshape.98 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(scatter.97), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  broadcast.99 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  reshape.100 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.99), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  broadcast.101 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.100), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  constant.17 = f32[] constant(1)
  broadcast.18 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.17), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  subtract.102 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.101, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
  multiply.131 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.102, subtract.102), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  constant.3 = f32[] constant(5.29)
  broadcast.4 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.3), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  divide.132 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.131, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  constant.27 = f32[] constant(0)
  reduce.137 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.132, constant.27), dimensions={5}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  constant.1 = f32[] constant(-0.5)
  broadcast.2 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(constant.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.138 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.137, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  exponential.139 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.138), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  constant.19 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  dot.103 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.97, constant.19), lhs_contracting_dims={4}, rhs_contracting_dims={0}, metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=60}
  reshape.104 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.103), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  broadcast.105 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.104), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  reshape.106 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.105), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  broadcast.107 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.106), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  subtract.108 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.107, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
  multiply.140 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.108, subtract.108), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  divide.141 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.140, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  reduce.146 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.141, constant.27), dimensions={5}, to_apply=region_3.142, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  multiply.147 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.146, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  exponential.148 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.147), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
  add.149 = f32[1,3540,71,75,71]{4,3,2,1,0} add(exponential.139, exponential.148), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
  multiply.154 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(broadcast.153, add.149), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
  ROOT reduce.159 = f32[3540,71]{1,0} reduce(multiply.154, constant.27), dimensions={0,2,3}, to_apply=region_4.155, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
} // main.160
mooskagh commented 1 month ago

Oh, those are quite old versions; apparently it changed between July and November 2023, and a lot has changed since then.

Not sure whether bisection is doable or useful.

Could you confirm that it doesn't work with the most recent JAX (0.4.32)? (No need to confirm; I'm able to repro it.)

mooskagh commented 1 month ago

There are instructions that output an f32[3540,71,75,71] tensor, which is what takes up the 5 GB of VRAM.
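The 5 GB figure follows from the shape alone: an f32 element is 4 bytes, so the element count times 4 gives the buffer size. A quick check in Python (the shape is copied from the HLO above; nothing else is assumed):

```python
import math

# Shape of the materialized intermediate, taken from the HLO dump
shape = (3540, 71, 75, 71)
BYTES_PER_F32 = 4

num_elements = math.prod(shape)
size_bytes = num_elements * BYTES_PER_F32

print(f"{num_elements:,} elements")
print(f"{size_bytes / 1e9:.2f} GB ({size_bytes / 2**30:.2f} GiB)")
```

This prints 1,338,385,500 elements and 5.35 GB (4.99 GiB) for a single such buffer.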

The older version emits everything as a single kernel (see the graph below), while the newer one splits it into three kernels, which forces that tensor to be materialized.

Generally, when partitioning the graph into kernels, the compiler doesn't try to optimize for memory usage, only for runtime. However, I doubt that materializing 5 GB twice would make it faster.

[Image: graph of the kernels emitted by the older and newer versions]

ghostway0 commented 1 month ago

Hello! It might not have fused them because #8170 was merged this year (i.e., after 326f72f).

pwithams commented 1 month ago

Is there a way to recreate this as an XLA test locally? I'm happy to try it myself. The test could then be run against a few versions/commits to narrow down and verify which change caused it.
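Short of a proper XLA test, one way to get a version-comparable artifact is to print the HLO directly from JAX's ahead-of-time `lower`/`compile` API and diff it across JAX versions. A minimal sketch; the function `f` and the input shapes are hypothetical stand-ins, not the reporter's script:

```python
import jax
import jax.numpy as jnp

def f(a, b):
    # Hypothetical stand-in: pairwise differences, squared, summed, exponentiated, reduced
    d = a[:, None, :] - b[None, :, :]
    return jnp.exp(-0.5 * jnp.sum(d * d, axis=-1)).sum(axis=1)

a = jnp.ones((16, 3), dtype=jnp.float32)
b = jnp.ones((8, 3), dtype=jnp.float32)

lowered = jax.jit(f).lower(a, b)
compiled = lowered.compile()

# HLO as emitted by JAX, before XLA's optimization passes
hlo_before = lowered.as_text(dialect="hlo")
# HLO after XLA's passes, which includes the fusion decisions
hlo_after = compiled.as_text()

print(hlo_after)
```

Saving `hlo_after` under each installed JAX version and diffing the files is one way to see where fusion decisions diverge without bisecting XLA itself.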

A flag to disable whatever optimization is causing this would also work. Is XLA able to take the available memory into account when optimizing? Individual GPUs typically don't have that much memory. I also experimented with pmap-ing it over multiple GPUs, but if I remember correctly, I got the same-sized memory error.

Obviously, in this particular case I'd prefer it to be slightly less time-optimized so that it can run on (fit into) a given GPU.
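If a slower program is acceptable, one manual workaround (not proposed in this thread) is to bound the intermediate yourself by chunking the large axis with `jax.lax.map`, which applies its body sequentially over the leading axis. The sketch below uses a hypothetical pairwise Gaussian loosely modeled on the `exp(-0.5 * sum(d * d / 5.29))` pattern visible in the HLO above; it is not the reporter's code, and the chunk size of 128 is an arbitrary assumption:

```python
import jax
import jax.numpy as jnp

SIGMA_SQ = 5.29  # constant seen in the HLO divide; used here only for illustration

def kernel_sum(a, b):
    # a: [n, 3], b: [m, 3] -> [n]; conceptually builds an [n, m, 3] intermediate
    d = a[:, None, :] - b[None, :, :]
    return jnp.exp(-0.5 * jnp.sum(d * d / SIGMA_SQ, axis=-1)).sum(axis=1)

def kernel_sum_chunked(a, b, chunk=128):
    # Hypothetical workaround: process `a` in fixed-size chunks so that only a
    # [chunk, m, 3] intermediate is needed per loop iteration.
    n = a.shape[0]
    pad = (-n) % chunk                      # pad n up to a multiple of chunk
    a_padded = jnp.pad(a, ((0, pad), (0, 0)))
    chunks = a_padded.reshape(-1, chunk, a.shape[1])
    out = jax.lax.map(lambda c: kernel_sum(c, b), chunks)  # sequential loop
    return out.reshape(-1)[:n]              # drop results for the padded rows
```

The design trade-off is the one described above: peak memory should scale with `chunk` rather than with `n`, at the cost of a sequential loop instead of one large kernel.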

mooskagh commented 1 month ago

Knowing the commit that changed the behavior won't help. The code that builds fusions has been completely rewritten since then anyway, so we'd need to check whether the current code can be fixed so that it fuses this into a single kernel again.
I'll try to take a look this week.

I would also be curious to look at the before_optimizations HLO of the shape=(1590, 3) variant, which appears to make a different fusion decision. (Just one version is enough; the before_optimizations HLOs are actually identical.)

pwithams commented 1 month ago

Ok that makes sense, thanks!

I generated an XLA dump for the latest version of JAX (v0.4.33), and both the before- and after-optimization files were the same for shape=(1594, 3) and shape=(1595, 3), despite 1595 giving an OOM and 1594 passing.

I looked at the buffer assignment files, and both allocate close to 5 GB: 4.49 GB for the one that works and 4.5 GB for the one that fails. So I assume the 1594/1595 threshold is just the limit of what can be allocated locally, rather than a change in actual fusion behaviour.

The JAX docs (https://jax.readthedocs.io/en/latest/device_memory_profiling.html) indicate that jax.profiler.save_device_memory_profile can be used with pprof to profile GPU memory usage. However, when I use it with the working 1594-shape example, it reports that the program only uses ~500 kB, even though the buffer assignment says ~5 GB is allocated. I don't think I can profile the failing one, as it crashes.

I'm not sure if I'm misinterpreting the save_device_memory_profile results, but could that suggest memory is being allocated but not used? I initially assumed that maybe it wasn't actually showing GPU memory, but the docs do seem to say "The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory".
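One way to probe what the profiler records is to run the same call on a toy program whose intermediate is much larger than its inputs and outputs, then compare the profile with that program's buffer assignment. A sketch following the usage in the linked docs; `f` and the sizes are hypothetical stand-ins, and whether XLA's temporary buffers ever appear in such a snapshot is an assumption to verify, not a documented fact:

```python
import os
import tempfile

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Any [n, n] intermediate is internal to the compiled executable;
    # only the [n] result is returned to Python as a jax.Array.
    return jnp.exp(-(x[:, None] - x[None, :]) ** 2).sum(axis=1)

x = jnp.linspace(0.0, 1.0, 2048)
y = f(x).block_until_ready()

# Snapshot of device memory at this point, written in pprof's protobuf format
prof_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(prof_path)
print(prof_path)
```

Because the profile is taken after `f` has returned, it can only show arrays that are still alive at that moment, such as `x` and `y`.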

I've attached the memory profile, visualized with pprof, for the working 1594 example that allocates 4.49 GB of GPU memory. Let me know if there's any other info I can provide.

[Image: profile001, pprof visualization of the device memory profile for the working 1594 example]