robelgeda / peytonites2024

Repository for the Princeton Astrophysical team at NVIDIA + Princeton Open Hackathon
https://peytonites2024.readthedocs.io/en/latest/index.html
MIT License
1 star · 1 fork · source link

JAX info #22

Open henryiii opened 1 month ago

henryiii commented 1 month ago

Here's the jaxpr for `step_fun` (the jitted step function, traced with 10,000 particles):

{ lambda ; a:f64[10000] b:f64[10000] c:f64[10000] d:f64[10000] e:f64[10000] f:f64[10000]
    g:f64[10000] h:f64[] i:f64[] j:f64[]. let
    k:f64[10000] l:f64[10000] m:f64[10000] n:f64[10000] o:f64[10000] p:f64[10000] = pjit[
      name=step_fun
      jaxpr={ lambda ; q:f64[] r:f64[] s:f64[] t:f64[10000] u:f64[10000] v:f64[10000]
          w:f64[10000] x:f64[10000] y:f64[10000] z:f64[10000]. let
          ba:f64[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] t
          bb:f64[1,10000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 10000)
          ] t
          bc:f64[10000,10000] = sub ba bb
          bd:f64[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] u
          be:f64[1,10000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 10000)
          ] u
          bf:f64[10000,10000] = sub bd be
          bg:f64[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] v
          bh:f64[1,10000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 10000)
          ] v
          bi:f64[10000,10000] = sub bg bh
          bj:f64[10000,10000] = integer_pow[y=2] bc
          bk:f64[10000,10000] = integer_pow[y=2] bf
          bl:f64[10000,10000] = add bj bk
          bm:f64[10000,10000] = integer_pow[y=2] bi
          bn:f64[10000,10000] = add bl bm
          bo:f64[] = integer_pow[y=2] q
          bp:f64[10000,10000] = add bn bo
          bq:i64[10000] = iota[dimension=0 dtype=int64 shape=(10000,)]
          br:bool[10000] = lt bq 0
          bs:i64[10000] = add bq 10000
          bt:i64[10000] = select_n br bq bs
          bu:bool[10000] = lt bq 0
          bv:i64[10000] = add bq 10000
          bw:i64[10000] = select_n bu bq bv
          bx:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] bt
          by:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] bw
          bz:i32[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] bx
          ca:i32[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] by
          cb:i32[10000,2] = concatenate[dimension=1] bz ca
          cc:f64[10000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(10000,)
          ] 1.0
          cd:f64[10000,10000] = scatter[
            dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1))
            indices_are_sorted=False
            mode=GatherScatterMode.FILL_OR_DROP
            unique_indices=False
            update_consts=()
            update_jaxpr=None
          ] bp cb cc
          ce:f64[10000,10000] = sqrt cd
          cf:f64[10000,10000] = mul cd ce
          cg:f64[] = neg r
          ch:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cg
          ci:f64[10000] = mul ch z
          cj:f64[1,10000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 10000)
          ] ci
          ck:f64[10000,10000] = div cj cf
          cl:i64[10000] = iota[dimension=0 dtype=int64 shape=(10000,)]
          cm:bool[10000] = lt cl 0
          cn:i64[10000] = add cl 10000
          co:i64[10000] = select_n cm cl cn
          cp:bool[10000] = lt cl 0
          cq:i64[10000] = add cl 10000
          cr:i64[10000] = select_n cp cl cq
          cs:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] co
          ct:i32[10000] = convert_element_type[new_dtype=int32 weak_type=False] cr
          cu:i32[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] cs
          cv:i32[10000,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(10000, 1)
          ] ct
          cw:i32[10000,2] = concatenate[dimension=1] cu cv
          cx:f64[10000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(10000,)
          ] 0.0
          cy:f64[10000,10000] = scatter[
            dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1))
            indices_are_sorted=False
            mode=GatherScatterMode.FILL_OR_DROP
            unique_indices=False
            update_consts=()
            update_jaxpr=None
          ] ck cw cx
          cz:f64[10000,10000] = mul cy bc
          da:f64[10000] = reduce_sum[axes=(1,)] cz
          db:f64[10000,10000] = mul cy bf
          dc:f64[10000] = reduce_sum[axes=(1,)] db
          dd:f64[10000,10000] = mul cy bi
          de:f64[10000] = reduce_sum[axes=(1,)] dd
          df:f64[10000] = mul da s
          dg:f64[10000] = add w df
          dh:f64[10000] = mul dc s
          di:f64[10000] = add x dh
          dj:f64[10000] = mul de s
          dk:f64[10000] = add y dj
          dl:f64[10000] = mul dg s
          dm:f64[10000] = add t dl
          dn:f64[10000] = mul di s
          do:f64[10000] = add u dn
          dp:f64[10000] = mul dk s
          dq:f64[10000] = add v dp
        in (dm, do, dq, dg, di, dk) }
    ] h i j a b c d e f g
  in (k, l, m, n, o, p) }