pwithams opened 1 month ago
Could you also share the before_optimization module?
Also, what GPU are you running it on?
Locally I have an NVIDIA GeForce RTX 4050 card. I am also seeing the issue on AWS g4dn instances with one or more NVIDIA T4 GPUs.
These outputs are from an NVIDIA driver with CUDA 12.3, but upgrading to the latest jax version (0.4.31) with CUDA 12.6 appears to show the same behavior.
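The two dumps below embed different `source_file` paths in their metadata, so a plain line-by-line diff is noisy. Here is a minimal sketch for comparing them, assuming each dump has been saved as text. The helper names `strip_metadata` and `diff_hlo` are hypothetical and not part of jax or XLA.

```python
import difflib
import re

# Matches the trailing ", metadata={...}" attribute on an HLO instruction line.
_METADATA = re.compile(r",\s*metadata=\{.*\}\s*$")


def strip_metadata(hlo_text: str) -> list[str]:
    """Return the HLO lines with trailing metadata and outer whitespace removed."""
    return [_METADATA.sub("", line.strip()) for line in hlo_text.splitlines()]


def diff_hlo(before: str, after: str) -> list[str]:
    """Unified diff of two HLO dumps, ignoring op_name/source_file/source_line."""
    return list(
        difflib.unified_diff(
            strip_metadata(before),
            strip_metadata(after),
            fromfile="jax-0.4.14",
            tofile="jax-0.4.16",
            lineterm="",
        )
    )
```

An empty result means the two modules differ only in metadata. Any remaining `-`/`+` lines point at real differences in the instructions.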
jax==0.4.14, xla_client._version=174, module_0000.jit_run.before_optimizations.txt
HloModule jit_run, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}
atleast_2d.30 {
ROOT Arg_0.31 = f32[71,75,11,3]{3,2,1,0} parameter(0)
}
atleast_2d_0.33 {
ROOT Arg_0.34 = f32[1,71,3]{2,1,0} parameter(0)
}
_where.39 {
Arg_0.40 = pred[] parameter(0)
Arg_1.41 = s32[] parameter(1)
Arg_2.42 = s32[] parameter(2)
ROOT select.43 = s32[] select(Arg_0.40, Arg_1.41, Arg_2.42), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/jit(_where)/select_n" source_file="/code/oom_gh_example.py" source_line=15}
}
remainder.44 {
Arg_0.45 = s32[71]{0} parameter(0)
Arg_1.46 = s32[] parameter(1)
constant.50 = s32[] constant(0)
compare.51 = pred[] compare(Arg_1.46, constant.50), direction=EQ, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/eq" source_file="/code/oom_gh_example.py" source_line=15}
constant.49 = s32[] constant(1)
call.52 = s32[] call(compare.51, constant.49, Arg_1.46), to_apply=_where.39
broadcast.53 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/code/oom_gh_example.py" source_line=15}
remainder.54 = s32[71]{0} remainder(Arg_0.45, broadcast.53), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/code/oom_gh_example.py" source_line=15}
constant.47 = s32[] constant(0)
broadcast.48 = s32[71]{0} broadcast(constant.47), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
compare.56 = pred[71]{0} compare(remainder.54, broadcast.48), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/code/oom_gh_example.py" source_line=15}
compare.57 = pred[] compare(call.52, constant.50), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/code/oom_gh_example.py" source_line=15}
broadcast.58 = pred[71]{0} broadcast(compare.57), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
compare.59 = pred[71]{0} compare(compare.56, broadcast.58), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
compare.55 = pred[71]{0} compare(remainder.54, broadcast.48), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/code/oom_gh_example.py" source_line=15}
and.60 = pred[71]{0} and(compare.59, compare.55), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/and" source_file="/code/oom_gh_example.py" source_line=15}
broadcast.61 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/code/oom_gh_example.py" source_line=15}
add.62 = s32[71]{0} add(remainder.54, broadcast.61), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/code/oom_gh_example.py" source_line=15}
ROOT select.63 = s32[71]{0} select(and.60, add.62, remainder.54), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/select_n" source_file="/code/oom_gh_example.py" source_line=15}
} // remainder.44
_roll_dynamic.64 {
Arg_0.65 = f32[1,71,75,82,3]{4,3,2,1,0} parameter(0)
concatenate.77 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(Arg_0.65, Arg_0.65), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=15}
iota.83 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
reshape.84 = s32[71,1]{1,0} reshape(iota.83), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/code/oom_gh_example.py" source_line=15}
constant.73 = s32[] constant(82)
broadcast.74 = s32[71]{0} broadcast(constant.73), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
Arg_1.66 = s32[71]{0} parameter(1)
constant.75 = s32[] constant(82)
call.76 = s32[71]{0} call(Arg_1.66, constant.75), to_apply=remainder.44
subtract.78 = s32[71]{0} subtract(broadcast.74, call.76), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/code/oom_gh_example.py" source_line=15}
constant.71 = s32[] constant(0)
broadcast.72 = s32[71]{0} broadcast(constant.71), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
compare.79 = pred[71]{0} compare(subtract.78, broadcast.72), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/code/oom_gh_example.py" source_line=15}
constant.69 = s32[] constant(164)
broadcast.70 = s32[71]{0} broadcast(constant.69), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
add.80 = s32[71]{0} add(subtract.78, broadcast.70), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/code/oom_gh_example.py" source_line=15}
select.81 = s32[71]{0} select(compare.79, add.80, subtract.78), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/code/oom_gh_example.py" source_line=15}
reshape.82 = s32[71,1]{1,0} reshape(select.81), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/code/oom_gh_example.py" source_line=15}
constant.67 = s32[] constant(0)
broadcast.68 = s32[71,1]{1,0} broadcast(constant.67), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/code/oom_gh_example.py" source_line=15}
concatenate.85 = s32[71,3]{1,0} concatenate(reshape.84, reshape.82, broadcast.68), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/code/oom_gh_example.py" source_line=15}
ROOT gather.86 = f32[1,71,75,82,3]{4,3,2,1,0} gather(concatenate.77, concatenate.85), offset_dims={0,2,3,4}, collapsed_slice_dims={1}, start_index_map={1,3,4}, index_vector_dim=1, slice_sizes={1,1,75,82,3}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=15}
} // _roll_dynamic.64
region_0.94 {
Arg_0.95 = f32[] parameter(0)
ROOT Arg_1.96 = f32[] parameter(1)
}
region_1.116 {
Arg_0.117 = f32[] parameter(0)
Arg_1.118 = f32[] parameter(1)
ROOT multiply.119 = f32[] multiply(Arg_0.117, Arg_1.118), metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/code/oom_gh_example.py" source_line=47}
}
region_2.133 {
Arg_0.134 = f32[] parameter(0)
Arg_1.135 = f32[] parameter(1)
ROOT add.136 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
}
region_3.142 {
Arg_0.143 = f32[] parameter(0)
Arg_1.144 = f32[] parameter(1)
ROOT add.145 = f32[] add(Arg_0.143, Arg_1.144), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=50}
}
region_4.155 {
Arg_0.156 = f32[] parameter(0)
Arg_1.157 = f32[] parameter(1)
ROOT add.158 = f32[] add(Arg_0.156, Arg_1.157), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
}
ENTRY main.160 {
constant.25 = f32[] constant(1)
broadcast.26 = f32[1,71]{1,0} broadcast(constant.25), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/code/oom_gh_example.py" source_line=74}
iota.29 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/code/oom_gh_example.py" source_line=77}
constant.15 = s32[] constant(0)
broadcast.16 = s32[71]{0} broadcast(constant.15), dimensions={}, metadata={op_name="jit(run)/jit(main)/lt" source_file="/code/oom_gh_example.py" source_line=63}
compare.109 = pred[71]{0} compare(iota.29, broadcast.16), direction=LT, metadata={op_name="jit(run)/jit(main)/lt" source_file="/code/oom_gh_example.py" source_line=63}
constant.13 = s32[] constant(71)
broadcast.14 = s32[71]{0} broadcast(constant.13), dimensions={}, metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=63}
add.110 = s32[71]{0} add(iota.29, broadcast.14), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=63}
select.111 = s32[71]{0} select(compare.109, add.110, iota.29), metadata={op_name="jit(run)/jit(main)/select_n" source_file="/code/oom_gh_example.py" source_line=63}
reshape.112 = s32[71,1]{1,0} reshape(select.111), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/code/oom_gh_example.py" source_line=63}
gather.113 = f32[1,71,1]{2,1,0} gather(broadcast.26, reshape.112), offset_dims={0,2}, collapsed_slice_dims={}, start_index_map={1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=63}
reshape.114 = f32[1,71]{1,0} reshape(gather.113), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(2,)]" source_file="/code/oom_gh_example.py" source_line=63}
constant.7 = f32[] constant(0.1)
broadcast.8 = f32[1,71]{1,0} broadcast(constant.7), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
multiply.115 = f32[1,71]{1,0} multiply(reshape.114, broadcast.8), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
reshape.123 = f32[1,71,1]{2,1,0} reshape(multiply.115), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 1) broadcast_dimensions=(0, 1)]" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.124 = f32[1,71,1]{2,1,0} broadcast(reshape.123), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
reshape.125 = f32[1,71]{1,0} reshape(broadcast.124), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.126 = f32[1,71,71]{2,1,0} broadcast(reshape.125), dimensions={0,1}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
constant.9 = f32[] constant(2.3)
broadcast.10 = f32[1,71,3]{2,1,0} broadcast(constant.9), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=64}
constant.28 = f32[] constant(1)
reduce.120 = f32[1,71]{1,0} reduce(broadcast.10, constant.28), dimensions={2}, to_apply=region_1.116, metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/code/oom_gh_example.py" source_line=47}
constant.5 = f32[] constant(15.7496099)
broadcast.6 = f32[1,71]{1,0} broadcast(constant.5), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=47}
multiply.121 = f32[1,71]{1,0} multiply(reduce.120, broadcast.6), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=47}
reshape.122 = f32[1,1,71]{2,1,0} reshape(multiply.121), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71) broadcast_dimensions=(0, 2)]" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.127 = f32[1,1,71]{2,1,0} broadcast(reshape.122), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
reshape.128 = f32[1,71]{1,0} reshape(broadcast.127), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.129 = f32[1,71,71]{2,1,0} broadcast(reshape.128), dimensions={0,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
divide.130 = f32[1,71,71]{2,1,0} divide(broadcast.126, broadcast.129), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=45}
reshape.150 = f32[1,1,71,1,71]{4,3,2,1,0} reshape(divide.130), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 1, 71) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.151 = f32[1,1,71,1,71]{4,3,2,1,0} broadcast(reshape.150), dimensions={0,1,2,3,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
reshape.152 = f32[1,71,71]{2,1,0} reshape(broadcast.151), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
broadcast.153 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(reshape.152), dimensions={0,2,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
constant.23 = f32[] constant(1)
broadcast.24 = f32[71,75,11,3]{3,2,1,0} broadcast(constant.23), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=8}
call.32 = f32[71,75,11,3]{3,2,1,0} call(broadcast.24), to_apply=atleast_2d.30
reshape.37 = f32[1,71,75,11,3]{4,3,2,1,0} reshape(call.32), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
constant.11 = f32[] constant(1)
broadcast.12 = f32[1,71,3]{2,1,0} broadcast(constant.11), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 3) broadcast_dimensions=(1, 2)]" source_file="/code/oom_gh_example.py" source_line=12}
call.35 = f32[1,71,3]{2,1,0} call(broadcast.12), to_apply=atleast_2d_0.33
broadcast.36 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(call.35), dimensions={0,3,4}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/code/oom_gh_example.py" source_line=9}
concatenate.38 = f32[1,71,75,82,3]{4,3,2,1,0} concatenate(reshape.37, broadcast.36), dimensions={3}, metadata={op_name="jit(run)/jit(main)/concatenate[dimension=3]" source_file="/code/oom_gh_example.py" source_line=9}
call.87 = f32[1,71,75,82,3]{4,3,2,1,0} call(concatenate.38, iota.29), to_apply=_roll_dynamic.64
slice.88 = f32[1,71,75,71,3]{4,3,2,1,0} slice(call.87), slice={[0:1], [0:71], [0:75], [0:71], [0:3]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=16}
constant.20 = s32[1]{0} constant({2})
slice.89 = f32[1,71,75,71,1]{4,3,2,1,0} slice(slice.88), slice={[0:1], [0:71], [0:75], [0:71], [2:3]}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(3, 4)) slice_sizes=(1, 71, 75, 71, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/code/oom_gh_example.py" source_line=17}
reshape.90 = f32[1,71,75,71]{3,2,1,0} reshape(slice.89), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(4,)]" source_file="/code/oom_gh_example.py" source_line=17}
constant.21 = f32[] constant(0.03)
broadcast.22 = f32[1,71,75,71]{3,2,1,0} broadcast(constant.21), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
subtract.91 = f32[1,71,75,71]{3,2,1,0} subtract(reshape.90, broadcast.22), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=17}
abs.92 = f32[1,71,75,71]{3,2,1,0} abs(subtract.91), metadata={op_name="jit(run)/jit(main)/abs" source_file="/code/oom_gh_example.py" source_line=17}
add.93 = f32[1,71,75,71]{3,2,1,0} add(abs.92, broadcast.22), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=17}
scatter.97 = f32[1,71,75,71,3]{4,3,2,1,0} scatter(slice.88, constant.20, add.93), update_window_dims={0,1,2,3}, inserted_window_dims={4}, scatter_dims_to_operand_dims={4}, index_vector_dim=0, indices_are_sorted=true, unique_indices=true, to_apply=region_0.94, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/code/oom_gh_example.py" source_line=17}
reshape.98 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(scatter.97), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/code/oom_gh_example.py" source_line=58}
broadcast.99 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
reshape.100 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.99), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
broadcast.101 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.100), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
constant.17 = f32[] constant(1)
broadcast.18 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.17), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
subtract.102 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.101, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=58}
multiply.131 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.102, subtract.102), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
constant.3 = f32[] constant(5.29)
broadcast.4 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.3), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
divide.132 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.131, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=49}
constant.27 = f32[] constant(0)
reduce.137 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.132, constant.27), dimensions={5}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=49}
constant.1 = f32[] constant(-0.5)
broadcast.2 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(constant.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
multiply.138 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.137, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=49}
exponential.139 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.138), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=49}
constant.19 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
dot.103 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.97, constant.19), lhs_contracting_dims={4}, rhs_contracting_dims={0}, metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/code/oom_gh_example.py" source_line=60}
reshape.104 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.103), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/code/oom_gh_example.py" source_line=61}
broadcast.105 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.104), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
reshape.106 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.105), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
broadcast.107 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.106), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
subtract.108 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.107, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/code/oom_gh_example.py" source_line=61}
multiply.140 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.108, subtract.108), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
divide.141 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.140, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/code/oom_gh_example.py" source_line=50}
reduce.146 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.141, constant.27), dimensions={5}, to_apply=region_3.142, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/code/oom_gh_example.py" source_line=50}
multiply.147 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.146, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=50}
exponential.148 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.147), metadata={op_name="jit(run)/jit(main)/exp" source_file="/code/oom_gh_example.py" source_line=50}
add.149 = f32[1,3540,71,75,71]{4,3,2,1,0} add(exponential.139, exponential.148), metadata={op_name="jit(run)/jit(main)/add" source_file="/code/oom_gh_example.py" source_line=49}
multiply.154 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(broadcast.153, add.149), metadata={op_name="jit(run)/jit(main)/mul" source_file="/code/oom_gh_example.py" source_line=45}
ROOT reduce.159 = f32[3540,71]{1,0} reduce(multiply.154, constant.27), dimensions={0,2,3}, to_apply=region_4.155, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/code/oom_gh_example.py" source_line=89}
} // main.160
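For scale, the largest intermediates in the module above have shape `f32[1,3540,71,75,71,3]` (for example `subtract.102` and `divide.132`). The sketch below estimates how much memory one such array would take if it were materialized in full. Whether XLA actually materializes it depends on how the optimized module fuses these ops, so this is an upper-bound estimate rather than a diagnosis. The helper name `hlo_shape_bytes` is hypothetical.

```python
import math
import re

# Bytes per element for the HLO element types that appear in this dump.
DTYPE_BYTES = {"f32": 4, "s32": 4, "pred": 1}

# Captures the element type and the comma-separated dimensions, e.g. "f32[1,71,3]".
_SHAPE = re.compile(r"^(\w+)\[([\d,]*)\]")


def hlo_shape_bytes(shape: str) -> int:
    """Size in bytes of an HLO array shape such as 'f32[1,71,3]{2,1,0}'."""
    match = _SHAPE.match(shape)
    if match is None:
        raise ValueError(f"unrecognized shape: {shape!r}")
    dtype, dims = match.groups()
    # An empty dimension list is a scalar, for which math.prod returns 1.
    num_elements = math.prod(int(d) for d in dims.split(",") if d)
    return DTYPE_BYTES[dtype] * num_elements


big = hlo_shape_bytes("f32[1,3540,71,75,71,3]{5,4,3,2,1,0}")
print(f"{big / 2**30:.2f} GiB")  # 14.96 GiB
```

By the same arithmetic, the reduced `f32[1,3540,71,75,71]` arrays (for example `reduce.137`) are one third of that, roughly 5 GiB each.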
jax==0.4.16, xla_client._version=194, module_0000.jit_run.before_optimizations.txt
HloModule jit_run, entry_computation_layout={()->f32[3540,71]{1,0}}, allow_spmd_sharding_propagation_to_output={true}
atleast_2d.30 {
ROOT Arg_0.31 = f32[71,75,11,3]{3,2,1,0} parameter(0)
}
atleast_2d_0.33 {
ROOT Arg_0.34 = f32[1,71,3]{2,1,0} parameter(0)
}
_where.39 {
Arg_0.40 = pred[] parameter(0)
Arg_1.41 = s32[] parameter(1)
Arg_2.42 = s32[] parameter(2)
ROOT select.43 = s32[] select(Arg_0.40, Arg_1.41, Arg_2.42), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/jit(_where)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
}
remainder.44 {
Arg_0.45 = s32[71]{0} parameter(0)
Arg_1.46 = s32[] parameter(1)
constant.50 = s32[] constant(0)
compare.51 = pred[] compare(Arg_1.46, constant.50), direction=EQ, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/eq" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.49 = s32[] constant(1)
call.52 = s32[] call(compare.51, constant.49, Arg_1.46), to_apply=_where.39
broadcast.53 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
remainder.54 = s32[71]{0} remainder(Arg_0.45, broadcast.53), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/rem" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.47 = s32[] constant(0)
broadcast.48 = s32[71]{0} broadcast(constant.47), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
compare.56 = pred[71]{0} compare(remainder.54, broadcast.48), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
compare.57 = pred[] compare(call.52, constant.50), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
broadcast.58 = pred[71]{0} broadcast(compare.57), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
compare.59 = pred[71]{0} compare(compare.56, broadcast.58), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
compare.55 = pred[71]{0} compare(remainder.54, broadcast.48), direction=NE, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/ne" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
and.60 = pred[71]{0} and(compare.59, compare.55), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/and" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
broadcast.61 = s32[71]{0} broadcast(call.52), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
add.62 = s32[71]{0} add(remainder.54, broadcast.61), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
ROOT select.63 = s32[71]{0} select(and.60, add.62, remainder.54), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/jit(remainder)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
} // remainder.44
_roll_dynamic.64 {
Arg_0.65 = f32[1,71,75,82,3]{4,3,2,1,0} parameter(0)
concatenate.77 = f32[1,71,75,164,3]{4,3,2,1,0} concatenate(Arg_0.65, Arg_0.65), dimensions={3}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
iota.83 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
reshape.84 = s32[71,1]{1,0} reshape(iota.83), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/iota[dtype=int32 shape=(71, 1) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.73 = s32[] constant(82)
broadcast.74 = s32[71]{0} broadcast(constant.73), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
Arg_1.66 = s32[71]{0} parameter(1)
constant.75 = s32[] constant(82)
call.76 = s32[71]{0} call(Arg_1.66, constant.75), to_apply=remainder.44
subtract.78 = s32[71]{0} subtract(broadcast.74, call.76), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.71 = s32[] constant(0)
broadcast.72 = s32[71]{0} broadcast(constant.71), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
compare.79 = pred[71]{0} compare(subtract.78, broadcast.72), direction=LT, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.69 = s32[] constant(164)
broadcast.70 = s32[71]{0} broadcast(constant.69), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
add.80 = s32[71]{0} add(subtract.78, broadcast.70), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
select.81 = s32[71]{0} select(compare.79, add.80, subtract.78), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
reshape.82 = s32[71,1]{1,0} reshape(select.81), metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
constant.67 = s32[] constant(0)
broadcast.68 = s32[71,1]{1,0} broadcast(constant.67), dimensions={}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(1,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
concatenate.85 = s32[71,3]{1,0} concatenate(reshape.84, reshape.82, broadcast.68), dimensions={1}, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/concatenate[dimension=1]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
ROOT gather.86 = f32[1,71,75,82,3]{4,3,2,1,0} gather(concatenate.77, concatenate.85), offset_dims={0,2,3,4}, collapsed_slice_dims={1}, start_index_map={1,3,4}, index_vector_dim=1, slice_sizes={1,1,75,82,3}, indices_are_sorted=true, metadata={op_name="jit(run)/jit(main)/vmap(vmap(vmap(vmap(jit(_roll_dynamic)))))/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2, 3, 4), collapsed_slice_dims=(1,), start_index_map=(1, 3, 4)) slice_sizes=(1, 1, 75, 82, 3) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=15}
} // _roll_dynamic.64
region_0.94 {
Arg_0.95 = f32[] parameter(0)
ROOT Arg_1.96 = f32[] parameter(1)
}
region_1.116 {
Arg_0.117 = f32[] parameter(0)
Arg_1.118 = f32[] parameter(1)
ROOT multiply.119 = f32[] multiply(Arg_0.117, Arg_1.118), metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
}
region_2.133 {
Arg_0.134 = f32[] parameter(0)
Arg_1.135 = f32[] parameter(1)
ROOT add.136 = f32[] add(Arg_0.134, Arg_1.135), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
}
region_3.142 {
Arg_0.143 = f32[] parameter(0)
Arg_1.144 = f32[] parameter(1)
ROOT add.145 = f32[] add(Arg_0.143, Arg_1.144), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
}
region_4.155 {
Arg_0.156 = f32[] parameter(0)
Arg_1.157 = f32[] parameter(1)
ROOT add.158 = f32[] add(Arg_0.156, Arg_1.157), metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
}
ENTRY main.160 {
constant.25 = f32[] constant(1)
broadcast.26 = f32[1,71]{1,0} broadcast(constant.25), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71) broadcast_dimensions=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=74}
iota.29 = s32[71]{0} iota(), iota_dimension=0, metadata={op_name="jit(run)/jit(main)/iota[dtype=int32 shape=(71,) dimension=0]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=77}
constant.15 = s32[] constant(0)
broadcast.16 = s32[71]{0} broadcast(constant.15), dimensions={}, metadata={op_name="jit(run)/jit(main)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
compare.109 = pred[71]{0} compare(iota.29, broadcast.16), direction=LT, metadata={op_name="jit(run)/jit(main)/lt" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
constant.13 = s32[] constant(71)
broadcast.14 = s32[71]{0} broadcast(constant.13), dimensions={}, metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
add.110 = s32[71]{0} add(iota.29, broadcast.14), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
select.111 = s32[71]{0} select(compare.109, add.110, iota.29), metadata={op_name="jit(run)/jit(main)/select_n" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
reshape.112 = s32[71,1]{1,0} reshape(select.111), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(71, 1) broadcast_dimensions=(0,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
gather.113 = f32[1,71,1]{2,1,0} gather(broadcast.26, reshape.112), offset_dims={0,2}, collapsed_slice_dims={}, start_index_map={1}, index_vector_dim=1, slice_sizes={1,1}, metadata={op_name="jit(run)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 2), collapsed_slice_dims=(), start_index_map=(1,)) slice_sizes=(1, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
reshape.114 = f32[1,71]{1,0} reshape(gather.113), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=63}
constant.7 = f32[] constant(0.1)
broadcast.8 = f32[1,71]{1,0} broadcast(constant.7), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
multiply.115 = f32[1,71]{1,0} multiply(reshape.114, broadcast.8), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
reshape.123 = f32[1,71,1]{2,1,0} reshape(multiply.115), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 1) broadcast_dimensions=(0, 1)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.124 = f32[1,71,1]{2,1,0} broadcast(reshape.123), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
reshape.125 = f32[1,71]{1,0} reshape(broadcast.124), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.126 = f32[1,71,71]{2,1,0} broadcast(reshape.125), dimensions={0,1}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
constant.9 = f32[] constant(2.3)
broadcast.10 = f32[1,71,3]{2,1,0} broadcast(constant.9), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=64}
constant.28 = f32[] constant(1)
reduce.120 = f32[1,71]{1,0} reduce(broadcast.10, constant.28), dimensions={2}, to_apply=region_1.116, metadata={op_name="jit(run)/jit(main)/reduce_prod[axes=(2,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
constant.5 = f32[] constant(15.7496099)
broadcast.6 = f32[1,71]{1,0} broadcast(constant.5), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
multiply.121 = f32[1,71]{1,0} multiply(reduce.120, broadcast.6), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=47}
reshape.122 = f32[1,1,71]{2,1,0} reshape(multiply.121), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71) broadcast_dimensions=(0, 2)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.127 = f32[1,1,71]{2,1,0} broadcast(reshape.122), dimensions={0,1,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
reshape.128 = f32[1,71]{1,0} reshape(broadcast.127), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.129 = f32[1,71,71]{2,1,0} broadcast(reshape.128), dimensions={0,2}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
divide.130 = f32[1,71,71]{2,1,0} divide(broadcast.126, broadcast.129), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
reshape.150 = f32[1,1,71,1,71]{4,3,2,1,0} reshape(divide.130), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 1, 71) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.151 = f32[1,1,71,1,71]{4,3,2,1,0} broadcast(reshape.150), dimensions={0,1,2,3,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
reshape.152 = f32[1,71,71]{2,1,0} reshape(broadcast.151), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
broadcast.153 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(reshape.152), dimensions={0,2,4}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
constant.23 = f32[] constant(1)
broadcast.24 = f32[71,75,11,3]{3,2,1,0} broadcast(constant.23), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=8}
call.32 = f32[71,75,11,3]{3,2,1,0} call(broadcast.24), to_apply=atleast_2d.30
reshape.37 = f32[1,71,75,11,3]{4,3,2,1,0} reshape(call.32), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 11, 3) broadcast_dimensions=(1, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
constant.11 = f32[] constant(1)
broadcast.12 = f32[1,71,3]{2,1,0} broadcast(constant.11), dimensions={}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 3) broadcast_dimensions=(1, 2)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=12}
call.35 = f32[1,71,3]{2,1,0} call(broadcast.12), to_apply=atleast_2d_0.33
broadcast.36 = f32[1,71,75,71,3]{4,3,2,1,0} broadcast(call.35), dimensions={0,3,4}, metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
concatenate.38 = f32[1,71,75,82,3]{4,3,2,1,0} concatenate(reshape.37, broadcast.36), dimensions={3}, metadata={op_name="jit(run)/jit(main)/concatenate[dimension=3]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=9}
call.87 = f32[1,71,75,82,3]{4,3,2,1,0} call(concatenate.38, iota.29), to_apply=_roll_dynamic.64
slice.88 = f32[1,71,75,71,3]{4,3,2,1,0} slice(call.87), slice={[0:1], [0:71], [0:75], [0:71], [0:3]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 0) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=16}
constant.20 = s32[1]{0} constant({2})
slice.89 = f32[1,71,75,71,1]{4,3,2,1,0} slice(slice.88), slice={[0:1], [0:71], [0:75], [0:71], [2:3]}, metadata={op_name="jit(run)/jit(main)/slice[start_indices=(0, 0, 0, 0, 2) limit_indices=(1, 71, 75, 71, 3) strides=None]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
reshape.90 = f32[1,71,75,71]{3,2,1,0} reshape(slice.89), metadata={op_name="jit(run)/jit(main)/squeeze[dimensions=(4,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
constant.21 = f32[] constant(0.03)
broadcast.22 = f32[1,71,75,71]{3,2,1,0} broadcast(constant.21), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
subtract.91 = f32[1,71,75,71]{3,2,1,0} subtract(reshape.90, broadcast.22), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
abs.92 = f32[1,71,75,71]{3,2,1,0} abs(subtract.91), metadata={op_name="jit(run)/jit(main)/abs" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
add.93 = f32[1,71,75,71]{3,2,1,0} add(abs.92, broadcast.22), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
scatter.97 = f32[1,71,75,71,3]{4,3,2,1,0} scatter(slice.88, constant.20, add.93), update_window_dims={0,1,2,3}, inserted_window_dims={4}, scatter_dims_to_operand_dims={4}, index_vector_dim=0, indices_are_sorted=true, unique_indices=true, to_apply=region_0.94, metadata={op_name="jit(run)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2, 3), inserted_window_dims=(4,), scatter_dims_to_operand_dims=(4,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=17}
reshape.98 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(scatter.97), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
broadcast.99 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
reshape.100 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.99), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
broadcast.101 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.100), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
constant.17 = f32[] constant(1)
broadcast.18 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.17), dimensions={}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
subtract.102 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.101, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=58}
multiply.131 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.102, subtract.102), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
constant.3 = f32[] constant(5.29)
broadcast.4 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(constant.3), dimensions={}, metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
divide.132 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.131, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
constant.27 = f32[] constant(0)
reduce.137 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.132, constant.27), dimensions={5}, to_apply=region_2.133, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
constant.1 = f32[] constant(-0.5)
broadcast.2 = f32[1,3540,71,75,71]{4,3,2,1,0} broadcast(constant.1), dimensions={}, metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
multiply.138 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.137, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
exponential.139 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.138), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
constant.19 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
dot.103 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.97, constant.19), lhs_contracting_dims={4}, rhs_contracting_dims={0}, metadata={op_name="jit(run)/jit(main)/dot_general[dimension_numbers=(((4,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=60}
reshape.104 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.103), metadata={op_name="jit(run)/jit(main)/broadcast_in_dim[shape=(1, 1, 71, 75, 71, 3) broadcast_dimensions=(0, 2, 3, 4, 5)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
broadcast.105 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.104), dimensions={0,1,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
reshape.106 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.105), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
broadcast.107 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.106), dimensions={0,2,3,4,5}, metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
subtract.108 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.107, broadcast.18), metadata={op_name="jit(run)/jit(main)/sub" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=61}
multiply.140 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.108, subtract.108), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
divide.141 = f32[1,3540,71,75,71,3]{5,4,3,2,1,0} divide(multiply.140, broadcast.4), metadata={op_name="jit(run)/jit(main)/div" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
reduce.146 = f32[1,3540,71,75,71]{4,3,2,1,0} reduce(divide.141, constant.27), dimensions={5}, to_apply=region_3.142, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(5,)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
multiply.147 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(reduce.146, broadcast.2), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
exponential.148 = f32[1,3540,71,75,71]{4,3,2,1,0} exponential(multiply.147), metadata={op_name="jit(run)/jit(main)/exp" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=50}
add.149 = f32[1,3540,71,75,71]{4,3,2,1,0} add(exponential.139, exponential.148), metadata={op_name="jit(run)/jit(main)/add" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=49}
multiply.154 = f32[1,3540,71,75,71]{4,3,2,1,0} multiply(broadcast.153, add.149), metadata={op_name="jit(run)/jit(main)/mul" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=45}
ROOT reduce.159 = f32[3540,71]{1,0} reduce(multiply.154, constant.27), dimensions={0,2,3}, to_apply=region_4.155, metadata={op_name="jit(run)/jit(main)/reduce_sum[axes=(0, 2, 3)]" source_file="/home/qubeds/workspaces/platform/oom_gh_example.py" source_line=89}
} // main.160
Oh, those are quite old versions. Apparently it changed between July and November 2023, and a lot has changed since then.
Not sure whether bisection is doable or useful.
Could you confirm that it doesn't work with the most recent JAX (0.4.32)? (No need to confirm, I'm able to repro it.)
There are instructions that output an f32[3540,71,75,71] tensor, which is what takes 5G of VRAM.
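As a rough check on the 5G figure, the size of that tensor follows directly from its shape and the 4-byte f32 element size. A minimal sketch (the helper name tensor_bytes is only for illustration):

```python
from math import prod

def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor; f32 elements are 4 bytes each."""
    return prod(shape) * bytes_per_element

# Shape taken from the f32[3540,71,75,71] intermediates in the HLO dump.
size = tensor_bytes((3540, 71, 75, 71))
print(f"{size} bytes = {size / 1e9:.2f} GB = {size / 2**30:.2f} GiB")
```

This comes to 5,353,542,000 bytes, i.e. about 5.35 GB (4.99 GiB) for each materialized copy of the tensor.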
Older version emits everything as a single kernel (see the graph below), while the newer one splits it into three (causing the materialization of that tensor).
Generally, when partitioning the graph into kernels, the compiler doesn't try to optimize for memory usage, only for runtime. However, I doubt that materializing 5G twice would make it faster.
hello! it might've not fused them because #8170 got merged this year (i.e. after 326f72f)
Is there a way to recreate this as an XLA test locally? I'm happy to try myself. This could then be run against a few versions/commits to try to narrow down/verify what change caused it.
A flag to disable whatever optimization is causing this would also work - is XLA able to take into account the availability of memory when optimizing? Typically individual GPUs don't have that much memory - I also experimented with trying to pmap it over multiple GPUs but, as far as I remember, got a memory error of the same size.
Obviously in this particular case I'd prefer it to be slightly less time-optimized in order to run/fit on a given GPU.
It won't help to know the commit that changed the behavior. The code that builds fusions has been completely rewritten since then anyway, so we'd need to look at whether it's possible to fix the current code so that it fuses everything into a single kernel again.
I'll try to take a look this week.
I would also be curious to look into the before_optimizations HLO of the shape=(1590, 3) variant, which appears to have a different fusion decision. (Just one version is enough; the before_optimizations HLOs are actually identical.)
Ok that makes sense, thanks!
I generated an XLA dump for the latest version of Jax (v0.4.33) and both before and after optimization files were the same for shape=(1594, 3) and shape=(1595, 3), despite 1595 giving an OOM and the other passing.
I looked at the buffer assignment files and both are allocating close to 5GB, but the one that works is 4.49GB and the one that fails is 4.5GB, so I assume the 1594/1595 threshold is just the limit of what can be allocated locally, but no change in actual fusion behaviour.
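The 4.49GB figure is at least consistent with two large f32 intermediates being materialized. This is only a sketch under assumptions the dumps here do not confirm: that the first dimension of dls takes the place of the 3540 dimension in the f32[3540,71,75,71] tensor, and that exactly two such tensors are live at once. The helper name is hypothetical.

```python
from math import prod

# Assumption: trailing dims come from the f32[3540,71,75,71] tensor in the
# HLO dump; the leading dim is replaced by the first dim of dls.
TRAILING_DIMS = (71, 75, 71)
F32_BYTES = 4

def two_buffer_bytes(n_rows):
    """Bytes for two f32[n_rows,71,75,71] buffers (assumed, not measured)."""
    return 2 * n_rows * prod(TRAILING_DIMS) * F32_BYTES

for n in (1594, 1595):
    total = two_buffer_bytes(n)
    print(n, total, f"{total / 2**30:.4f} GiB")
```

Under these assumptions 1594 rows need 4,821,212,400 bytes (about 4.49 GiB) and each extra row adds roughly 3 MB, so the two cases differ only marginally in size. That would fit the reading that the 1594/1595 threshold is an allocation limit rather than a change in fusion behaviour.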
The jax docs (https://jax.readthedocs.io/en/latest/device_memory_profiling.html) indicate that jax.profiler.save_device_memory_profile can be used with pprof to profile GPU memory usage. However, when I use that with the working 1594 shape example it says the program only uses ~500kB, despite the buffer assignment saying it allocated ~5GB. I don't think I can profile the failing one as it crashes.
Not sure if I'm misinterpreting the save_device_memory_profile results, but could that suggest memory is being allocated but not used? I initially assumed that maybe it wasn't actually showing GPU memory, but the docs do seem to say "The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory".
I've attached the memory profile visualized with pprof for the working 1594 example that allocates 4.49GB of GPU memory. Let me know if there's any other info I can provide.
This is a more XLA-specific version of https://github.com/google/jax/issues/23548, encountered using an NVIDIA GPU.
Basically when using a value of dls=jnp.ones(shape=(1590, 3)) the program ran successfully and pprof reported ~500kB of memory usage, but increasing to dls=jnp.ones(shape=(1600, 3)) fails trying to allocate ~5GB. So it seems like there might be a difference between requested GPU memory and actual usage.
The Jax script described in that ticket works on Jax v0.4.14 but not on versions v0.4.16 and above. The python xla_client version for v0.4.14 is 174 and for v0.4.16 is 194, so I'm thinking the reason for this change in behaviour must have occurred between 5ca49a9 and 326f72f. I see there were some changes/additions to cudnn fusion logic between those commits, and the last line of module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt seems to reference "fusion" rather than "reduce" as before, not sure if that is relevant.
I've provided content of the module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt files below. Let me know if there's any other information I can provide.
xla_dump output
jax==0.4.14, xla_client._version=174, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt
jax==0.4.16, xla_client._version=194, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt
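For anyone wanting to regenerate dumps like these, XLA writes them when the XLA_FLAGS environment variable contains --xla_dump_to=<dir> before JAX is imported. A minimal sketch that builds such an environment for a subprocess; the helper name xla_dump_env and the directory /tmp/xla_dump are placeholders, not something from the original report.

```python
import os

def xla_dump_env(dump_dir, base_env=None):
    """Return a copy of the environment with XLA HLO dumping enabled.

    Any XLA_FLAGS already present are kept and the dump flags are appended.
    """
    env = dict(os.environ if base_env is None else base_env)
    flags = env.get("XLA_FLAGS", "").split()
    flags += [f"--xla_dump_to={dump_dir}", "--xla_dump_hlo_as_text"]
    env["XLA_FLAGS"] = " ".join(flags)
    return env

# Start from an empty environment here only to show the resulting flags.
env = xla_dump_env("/tmp/xla_dump", base_env={})
print(env["XLA_FLAGS"])
```

In practice xla_dump_env("/tmp/xla_dump") (without base_env, so PATH and the rest are inherited) could be passed as env= to subprocess.run([sys.executable, "oom_gh_example.py"], ...), where oom_gh_example.py is the script name that appears in the HLO metadata. The dump directory should then contain the before_optimizations and after_optimizations text files for each compiled module.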