[Open] henryiii opened this issue 1 month ago
Here's the jaxpr for `step_fn`:
{ lambda ; a:f64[10000] b:f64[10000] c:f64[10000] d:f64[10000] e:f64[10000] f:f64[10000] g:f64[10000] h:f64[] i:f64[] j:f64[]. let k:f64[10000] l:f64[10000] m:f64[10000] n:f64[10000] o:f64[10000] p:f64[10000] = pjit[ name=step_fun jaxpr={ lambda ; q:f64[] r:f64[] s:f64[] t:f64[10000] u:f64[10000] v:f64[10000] w:f64[10000] x:f64[10000] y:f64[10000] z:f64[10000]. let ba:f64[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] t bb:f64[1,10000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 10000) ] t bc:f64[10000,10000] = sub ba bb bd:f64[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] u be:f64[1,10000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 10000) ] u bf:f64[10000,10000] = sub bd be bg:f64[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] v bh:f64[1,10000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 10000) ] v bi:f64[10000,10000] = sub bg bh bj:f64[10000,10000] = integer_pow[y=2] bc bk:f64[10000,10000] = integer_pow[y=2] bf bl:f64[10000,10000] = add bj bk bm:f64[10000,10000] = integer_pow[y=2] bi bn:f64[10000,10000] = add bl bm bo:f64[] = integer_pow[y=2] q bp:f64[10000,10000] = add bn bo bq:i64[10000] = iota[dimension=0 dtype=int64 shape=(10000,)] br:bool[10000] = lt bq 0 bs:i64[10000] = add bq 10000 bt:i64[10000] = select_n br bq bs bu:bool[10000] = lt bq 0 bv:i64[10000] = add bq 10000 bw:i64[10000] = select_n bu bq bv bx:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] bt by:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] bw bz:i32[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] bx ca:i32[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] by cb:i32[10000,2] = concatenate[dimension=1] bz ca cc:f64[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) ] 1.0 cd:f64[10000,10000] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), 
inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] bp cb cc ce:f64[10000,10000] = sqrt cd cf:f64[10000,10000] = mul cd ce cg:f64[] = neg r ch:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cg ci:f64[10000] = mul ch z cj:f64[1,10000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 10000) ] ci ck:f64[10000,10000] = div cj cf cl:i64[10000] = iota[dimension=0 dtype=int64 shape=(10000,)] cm:bool[10000] = lt cl 0 cn:i64[10000] = add cl 10000 co:i64[10000] = select_n cm cl cn cp:bool[10000] = lt cl 0 cq:i64[10000] = add cl 10000 cr:i64[10000] = select_n cp cl cq cs:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] co ct:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] cr cu:i32[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] cs cv:i32[10000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(10000, 1) ] ct cw:i32[10000,2] = concatenate[dimension=1] cu cv cx:f64[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) ] 0.0 cy:f64[10000,10000] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] ck cw cx cz:f64[10000,10000] = mul cy bc da:f64[10000] = reduce_sum[axes=(1,)] cz db:f64[10000,10000] = mul cy bf dc:f64[10000] = reduce_sum[axes=(1,)] db dd:f64[10000,10000] = mul cy bi de:f64[10000] = reduce_sum[axes=(1,)] dd df:f64[10000] = mul da s dg:f64[10000] = add w df dh:f64[10000] = mul dc s di:f64[10000] = add x dh dj:f64[10000] = mul de s dk:f64[10000] = add y dj dl:f64[10000] = mul dg s dm:f64[10000] = add t dl dn:f64[10000] = mul di s do:f64[10000] = add u dn dp:f64[10000] = mul dk s dq:f64[10000] = add v dp in (dm, do, dq, dg, di, dk) 
} ] h i j a b c d e f g in (k, l, m, n, o, p) }
Here's the jaxpr for `step_fn`: