carlosgmartin opened 1 year ago
I wonder if checkify's NaN checker could narrow it down: https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html?highlight=checkify#the-checkify-transformation
Also, do you have a `jnp.where` in your `f`? See https://jax.readthedocs.io/en/latest/faq.html?highlight=nan#gradients-contain-nan-where-using-where
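For context, the FAQ entry linked above is about a pattern like the following. This is a minimal sketch adapted from that entry; the function names `log_where_naive` and `log_where_safe` are hypothetical. Both branches of `jnp.where` are differentiated, so at `x = 0` the zero cotangent sent into the untaken `log` branch meets `log`'s infinite derivative and produces `nan`. Guarding the input of `log` with a second `jnp.where` keeps that derivative finite.

```python
import jax
import jax.numpy as jnp

def log_where_naive(x):
    # Both branches are evaluated; at x = 0 the untaken log branch is -inf
    # and its derivative 1/x is infinite, so the backward pass computes 0/0.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def log_where_safe(x):
    # Feed log a harmless value (1.0) wherever its result is discarded,
    # so the derivative on the untaken branch stays finite.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

print(jax.grad(log_where_naive)(0.0))  # nan
print(jax.grad(log_where_safe)(0.0))   # 0.0
```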
@zhangqiaorjc I tried using checkify with `errors=checkify.all_checks` (is `checkify.all_errors` a typo on the linked page?) and `err.throw()`. No error was raised. To my knowledge, `f` does not use `jnp.where`.
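For reference, here is a minimal sketch of the checkify workflow being discussed, run on a hypothetical toy function rather than the `f` from this issue. `checkify.checkify` returns a function that outputs an error object alongside the result; `err.get()` returns the error message (or `None`), and `err.throw()` raises only if an error was recorded. As far as I know, the same wrapper can also be applied to `jax.grad(toy)` to check the backward pass.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def toy(x):
    # Hypothetical example: log of a negative number produces nan.
    return jnp.log(x)

# nan_checks flags primitives whose outputs contain a newly created nan;
# checkify.all_checks (used in the thread) enables other built-in checks too.
checked_toy = jax.jit(checkify.checkify(toy, errors=checkify.nan_checks))

err, out = checked_toy(jnp.float32(-1.0))
print(err.get())  # a message naming the primitive that produced the nan

err_ok, out_ok = checked_toy(jnp.float32(1.0))
print(err_ok.get())  # None
err_ok.throw()  # no-op, since no error was recorded
```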
Perhaps a clue lies in the difference between how the gradient of `_.max()` is computed and how the gradients of `_[_.argmax()]` and `jnp.nanmax(_)` are computed?
You can also use `jax_debug_nans`, though it might raise the error inside JAX's backward-pass code, which might be unfamiliar.
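A minimal sketch of toggling `jax_debug_nans` from Python instead of through the `JAX_DEBUG_NANS` environment variable; the helper `first_nan_error` is hypothetical. With the flag on, JAX raises `FloatingPointError` when an operation produces a `nan`; for jitted code it re-runs the computation op by op to locate the offending primitive, which is why the error can surface inside JAX's own backward-pass code.

```python
import jax
import jax.numpy as jnp

def first_nan_error(fn, *args):
    """Run fn(*args) with jax_debug_nans on; return the error message or None."""
    jax.config.update("jax_debug_nans", True)
    try:
        fn(*args)
        return None
    except FloatingPointError as e:
        return str(e)
    finally:
        # Restore the default so later code does not pay for the extra checks.
        jax.config.update("jax_debug_nans", False)

print(first_nan_error(jnp.log, -1.0))  # an error message mentioning the nan
print(first_nan_error(jnp.log, 1.0))   # None
```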
Any chance you could share a runnable repro? (The smaller the better!)
JAX treats differentiation of `lambda x: x[x.argmax()]` differently from differentiation of `lambda x: x.max()`. The difference is in how ties (i.e. non-unique maxima) are handled:
```python
import jax
import jax.numpy as jnp

def f1(x):
    return x.max()

def f2(x):
    return x[x.argmax()]

x = jnp.ones(2)
print(jax.grad(f1)(x))  # [0.5 0.5]
print(jax.grad(f2)(x))  # [1. 0.]
```
The gradient (in the sense of a representer vector for the Fréchet derivative) is mathematically not well defined for $(x, y) \mapsto \max(x, y)$ where $x = y$. So we have to either raise an error or choose some convention. One viable convention is to just pick one of the directional derivatives, e.g. the one corresponding to the first maximum (first according to, say, left-to-right order in a flattened version of the input). Rather than arbitrarily choosing one of the maxima, another convention is to be symmetric under permutation of the maxima.
In JAX we chose the latter. (Actually, we chose it when writing the original Autograd, then stuck with it for JAX.) But that means `lambda x: x[x.argmax()]` differentiates differently from `lambda x: x.max()`, because `argmax` is defined (in its docstring!) to choose the first maximum when there are several. (This pattern is common in autodiff: different programming-language denotations of the same mathematical function can lead to different automatic derivatives.)
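To make the two conventions concrete, here is a sketch that writes each one out by hand and checks it against `jax.grad`. The helpers `grad_max_symmetric` and `grad_max_first` are hypothetical, not JAX APIs. With a tie between entries 1 and 2, the symmetric convention gives each tied entry half of the gradient, while the first-maximum convention puts all of it on the lower index, as `argmax` does.

```python
import jax
import jax.numpy as jnp

def grad_max_symmetric(x):
    # Split the gradient evenly among all entries that attain the maximum.
    is_max = (x == jnp.max(x)).astype(x.dtype)
    return is_max / jnp.sum(is_max)

def grad_max_first(x):
    # Put the whole gradient on the first maximizer, as argmax would choose.
    return jax.nn.one_hot(jnp.argmax(x), x.shape[0], dtype=x.dtype)

x = jnp.array([1.0, 3.0, 3.0, 2.0])  # entries 1 and 2 are tied
print(grad_max_symmetric(x))                 # 0.5 on each tied entry
print(jax.grad(lambda v: v.max())(x))        # same as the line above
print(grad_max_first(x))                     # 1.0 on entry 1 only
print(jax.grad(lambda v: v[v.argmax()])(x))  # same as the line above
```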
A consequence is that there may be functions `f` and values of `x` such that `grad(lambda x: f1(f(x)))(x)` is `nan` but `grad(lambda x: f2(f(x)))(x)` is not. But I'm interested to see whether that's what's being constructed! Or maybe something else is going on.
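A smaller, self-contained instance of "one denotation gives `nan`, another does not", unrelated to `max` (the functions `abs_direct` and `abs_via_sqrt` are hypothetical examples): both denote $|x|$, but at `x = 0` the reverse pass of `abs_via_sqrt` multiplies the infinite derivative of `sqrt` at 0 by the zero derivative of `x ** 2`, and `inf * 0` is `nan`.

```python
import jax
import jax.numpy as jnp

def abs_direct(x):
    return jnp.abs(x)

def abs_via_sqrt(x):
    # Mathematically equal to |x|, but d sqrt(u)/du is infinite at u = 0,
    # and d(x ** 2)/dx is 0 at x = 0, so the chain rule computes inf * 0.
    return jnp.sqrt(x ** 2)

print(jax.grad(abs_direct)(0.0))     # a finite value (JAX's convention at 0)
print(jax.grad(abs_via_sqrt)(0.0))   # nan
print(jax.grad(abs_via_sqrt)(-2.0))  # -1.0, agreeing with abs away from 0
```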
In other words, I'm just explaining why `x.max()` and `x[x.argmax()]` could behave differently under autodiff. But I don't know exactly what's going on without a repro.
@mattjj This function is buried deep inside a jitted part of the program and is entangled with a lot of other code, so I'm having some trouble extracting a minimal example. Setting the `JAX_DEBUG_NANS=True` environment variable yields:

For some reason, the trace seems to go no "deeper" than the top-level `scan` that creates the jitted part of the program.
It's interesting that `jnp.nanmax` works when `jnp.max` doesn't, despite the absence of `nan` values. `nanmax` also uses the permutation-symmetric convention you mentioned. Where can one find the gradient-computation code for these two functions, to compare them side by side? Perhaps carefully examining the gradient code for `max` could reveal some kind of edge or failure case?
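One way to put the derivative computations side by side without reading JAX's source is to print the jaxpr of each gradient. This is a sketch on a hypothetical 1-D input with a tie; the exact jaxpr text varies across JAX versions, but `reduce_max` should appear for `x.max()` and `jnp.nanmax`, and `argmax` for `x[x.argmax()]`. In the source, if I recall correctly, the `reduce_max` derivative rule is `_reduce_chooser_jvp_rule` in `jax/_src/lax/lax.py`, and as far as I can tell `jnp.nanmax` applies the same max reduction after masking `nan`s with `jnp.where`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 3.0, 3.0, 2.0])  # hypothetical input with a tie

variants = {
    "max": lambda v: v.max(),
    "argmax-index": lambda v: v[v.argmax()],
    "nanmax": lambda v: jnp.nanmax(v),
}

# Print the full forward-plus-backward computation for each variant.
for name, fn in variants.items():
    print(f"--- grad of {name} ---")
    print(jax.make_jaxpr(jax.grad(fn))(x))
```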
Regarding a reproducible example, is it possible to get some representation of `f`, with concrete values for all constants, while deep inside a `jit`? I tried setting up some code to manually detect when a `nan` gradient appears and to run `jax.debug.callback(callback_fn)` if it does. However, playing around with `f` inside that callback leads to complaints about leaks:
I also tried explicitly passing the relevant functions and values to `jax.debug.callback` through `*args`, but then it complains that the functions are not valid JAX types.
I solved the above callback issues by using something of the form

```python
jax.debug.callback(functools.partial(callback_fn, *all_static_args), *all_dynamic_args)
```
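A self-contained sketch of that pattern, assuming a hypothetical `loss` whose gradient contains a `nan` (borrowed from the `jnp.where` FAQ case) and a hypothetical `report_nan` callback. Static Python objects such as functions and labels are bound with `functools.partial`, so only arrays pass through `jax.debug.callback`, and they arrive in the callback as concrete values rather than tracers.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

nan_reports = []  # hypothetical sink; a real callback might print or save to disk

def loss(x):
    # The untaken log branch makes the gradient nan at x == 0.
    return jnp.sum(jnp.where(x > 0.0, jnp.log(x), 0.0))

def report_nan(fn, label, x, grad):
    # fn and label are bound statically via functools.partial; x and grad
    # are concrete arrays here, so they can be inspected or saved freely.
    grad = np.asarray(grad)
    if np.isnan(grad).any():
        nan_reports.append((fn.__name__, label, np.asarray(x), grad))

@jax.jit
def step(x):
    g = jax.grad(loss)(x)
    jax.debug.callback(functools.partial(report_nan, loss, "step"), x, g)
    return g

g = step(jnp.array([0.0, 1.0]))
g.block_until_ready()
jax.effects_barrier()  # wait until the callback's side effects have finished
print(nan_reports)
```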
However, something strange happens: when I compute and print the same gradient inside the callback, the `nan`s disappear and the values are correct (still using `jnp.max`).
Description
I have a strange bug: letting `y = f(x)`, the gradient of `y.max()` is all `nan`s, even though the individual values and gradients are `nan`-free. Replacing it with `y[y.argmax()]` or `jnp.nanmax(y)` fixes this. `make_jaxpr(f)(x)` is the following:
```
{ lambda a:f64[1] b:f64[1] c:f64[1] d:f64[1] e:f64[1] f:f64[1] g:f64[1] h:f64[1] i:f64[1] j:f64[1] k:f64[1] l:f64[1] m:f64[2] n:f64[2] o:f64[2] p:f64[2] q:f64[2] r:f64[2] s:f64[2] t:f64[2] u:f64[2] v:f64[2] w:f64[2] x:f64[2] y:f64[2] z:f64[2] ba:f64[2] bb:f64[2] bc:f64[2] bd:f64[2] be:f64[2] bf:f64[2] bg:f64[2] bh:f64[2] bi:f64[2] bj:f64[2] bk:f64[2] bl:f64[2] bm:f64[2] bn:f64[2] bo:f64[2] bp:f64[2]; bq:f64[6]. let
    br:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] bq
    bs:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] bq
    bt:f64[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] bq
    bu:f64[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] bq
    bv:f64[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] bq
    bw:f64[1] = slice[limit_indices=(6,) start_indices=(5,) strides=None] bq
    bx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] a
    by:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] b
    bz:f64[2,1] = concatenate[dimension=0] bx by
    ca:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] c
    cb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] d
    cc:f64[2,1] = concatenate[dimension=0] ca cb
    cd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] e
    ce:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] f
    cf:f64[2,1] = concatenate[dimension=0] cd ce
    cg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] g
    ch:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] h
    ci:f64[2,1] = concatenate[dimension=0] cg ch
    cj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] i
    ck:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] j
    cl:f64[2,1] = concatenate[dimension=0] cj ck
    cm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] k
    cn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] l
    co:f64[2,1] = concatenate[dimension=0] cm cn
    cp:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; cq:f64[2,1] cr:i64[]. let
          cs:i64[1] = reshape[dimensions=None new_sizes=(1,)] cr
          ct:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] cs
          cu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ct
          cv:f64[2,2] = concatenate[dimension=1] cq cu
        in (cv,) }
      name=append
    ] bz 0
    cw:f64[2] = reduce_max[axes=(1,)] cp
    cx:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] cw
    cy:f64[2,1] = stop_gradient cx
    cz:f64[2,2] = sub cp cy
    da:f64[2,2] = exp cz
    db:f64[2] = reduce_sum[axes=(1,)] da
    dc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] db
    dd:f64[2,2] = div da dc
    de:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] dd
    df:f64[2] = squeeze[dimensions=(1,)] de
    dg:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] dd
    dh:f64[2] = squeeze[dimensions=(1,)] dg
    di:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; dj:f64[2,1] dk:i64[]. let
          dl:i64[1] = reshape[dimensions=None new_sizes=(1,)] dk
          dm:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] dl
          dn:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] dm
          do:f64[2,2] = concatenate[dimension=1] dj dn
        in (do,) }
      name=append
    ] cc 0
    dp:f64[2] = reduce_max[axes=(1,)] di
    dq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dp
    dr:f64[2,1] = stop_gradient dq
    ds:f64[2,2] = sub di dr
    dt:f64[2,2] = exp ds
    du:f64[2] = reduce_sum[axes=(1,)] dt
    dv:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] du
    dw:f64[2,2] = div dt dv
    dx:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] dw
    dy:f64[2] = squeeze[dimensions=(1,)] dx
    dz:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] dw
    ea:f64[2] = squeeze[dimensions=(1,)] dz
    eb:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; ec:f64[2,1] ed:i64[]. let
          ee:i64[1] = reshape[dimensions=None new_sizes=(1,)] ed
          ef:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ee
          eg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ef
          eh:f64[2,2] = concatenate[dimension=1] ec eg
        in (eh,) }
      name=append
    ] cf 0
    ei:f64[2] = reduce_max[axes=(1,)] eb
    ej:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ei
    ek:f64[2,1] = stop_gradient ej
    el:f64[2,2] = sub eb ek
    em:f64[2,2] = exp el
    en:f64[2] = reduce_sum[axes=(1,)] em
    eo:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] en
    ep:f64[2,2] = div em eo
    eq:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] ep
    er:f64[2] = squeeze[dimensions=(1,)] eq
    es:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] ep
    et:f64[2] = squeeze[dimensions=(1,)] es
    eu:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; ev:f64[2,1] ew:i64[]. let
          ex:i64[1] = reshape[dimensions=None new_sizes=(1,)] ew
          ey:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ex
          ez:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ey
          fa:f64[2,2] = concatenate[dimension=1] ev ez
        in (fa,) }
      name=append
    ] ci 0
    fb:f64[2] = reduce_max[axes=(1,)] eu
    fc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fb
    fd:f64[2,1] = stop_gradient fc
    fe:f64[2,2] = sub eu fd
    ff:f64[2,2] = exp fe
    fg:f64[2] = reduce_sum[axes=(1,)] ff
    fh:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fg
    fi:f64[2,2] = div ff fh
    fj:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] fi
    fk:f64[2] = squeeze[dimensions=(1,)] fj
    fl:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] fi
    fm:f64[2] = squeeze[dimensions=(1,)] fl
    fn:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; fo:f64[2,1] fp:i64[]. let
          fq:i64[1] = reshape[dimensions=None new_sizes=(1,)] fp
          fr:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] fq
          fs:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] fr
          ft:f64[2,2] = concatenate[dimension=1] fo fs
        in (ft,) }
      name=append
    ] cl 0
    fu:f64[2] = reduce_max[axes=(1,)] fn
    fv:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fu
    fw:f64[2,1] = stop_gradient fv
    fx:f64[2,2] = sub fn fw
    fy:f64[2,2] = exp fx
    fz:f64[2] = reduce_sum[axes=(1,)] fy
    ga:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fz
    gb:f64[2,2] = div fy ga
    gc:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] gb
    gd:f64[2] = squeeze[dimensions=(1,)] gc
    ge:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] gb
    gf:f64[2] = squeeze[dimensions=(1,)] ge
    gg:f64[2,2] = xla_call[
      call_jaxpr={ lambda ; gh:f64[2,1] gi:i64[]. let
          gj:i64[1] = reshape[dimensions=None new_sizes=(1,)] gi
          gk:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] gj
          gl:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] gk
          gm:f64[2,2] = concatenate[dimension=1] gh gl
        in (gm,) }
      name=append
    ] co 0
    gn:f64[2] = reduce_max[axes=(1,)] gg
    go:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gn
    gp:f64[2,1] = stop_gradient go
    gq:f64[2,2] = sub gg gp
    gr:f64[2,2] = exp gq
    gs:f64[2] = reduce_sum[axes=(1,)] gr
    gt:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gs
    gu:f64[2,2] = div gr gt
    gv:f64[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] gu
    gw:f64[2] = squeeze[dimensions=(1,)] gv
    gx:f64[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] gu
    gy:f64[2] = squeeze[dimensions=(1,)] gx
    gz:f64[2] = xla_call[
      call_jaxpr={ lambda ; ha:f64[1] hb:i64[]. let
          hc:i64[1] = reshape[dimensions=None new_sizes=(1,)] hb
          hd:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] hc
          he:f64[2] = concatenate[dimension=0] ha hd
        in (he,) }
      name=append
    ] bu 0
    hf:f64[] = reduce_max[axes=(0,)] gz
    hg:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hf
    hh:f64[1] = stop_gradient hg
    hi:f64[2] = sub gz hh
    hj:f64[2] = exp hi
    hk:f64[] = reduce_sum[axes=(0,)] hj
    hl:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hk
    hm:f64[2] = div hj hl
    hn:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] hm
    ho:f64[] = squeeze[dimensions=(0,)] hn
    hp:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] hm
    hq:f64[] = squeeze[dimensions=(0,)] hp
    hr:f64[2] = xla_call[
      call_jaxpr={ lambda ; hs:f64[1] ht:i64[]. let
          hu:i64[1] = reshape[dimensions=None new_sizes=(1,)] ht
          hv:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] hu
          hw:f64[2] = concatenate[dimension=0] hs hv
        in (hw,) }
      name=append
    ] bt 0
    hx:f64[] = reduce_max[axes=(0,)] hr
    hy:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hx
    hz:f64[1] = stop_gradient hy
    ia:f64[2] = sub hr hz
    ib:f64[2] = exp ia
    ic:f64[] = reduce_sum[axes=(0,)] ib
    id:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ic
    ie:f64[2] = div ib id
    if:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] ie
    ig:f64[] = squeeze[dimensions=(0,)] if
    ih:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] ie
    ii:f64[] = squeeze[dimensions=(0,)] ih
    ij:f64[2] = xla_call[
      call_jaxpr={ lambda ; ik:f64[1] il:i64[]. let
          im:i64[1] = reshape[dimensions=None new_sizes=(1,)] il
          in:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] im
          io:f64[2] = concatenate[dimension=0] ik in
        in (io,) }
      name=append
    ] bw 0
    ip:f64[] = reduce_max[axes=(0,)] ij
    iq:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ip
    ir:f64[1] = stop_gradient iq
    is:f64[2] = sub ij ir
    it:f64[2] = exp is
    iu:f64[] = reduce_sum[axes=(0,)] it
    iv:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] iu
    iw:f64[2] = div it iv
    ix:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] iw
    iy:f64[] = squeeze[dimensions=(0,)] ix
    iz:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] iw
    ja:f64[] = squeeze[dimensions=(0,)] iz
    jb:f64[2] = xla_call[
      call_jaxpr={ lambda ; jc:f64[1] jd:i64[]. let
          je:i64[1] = reshape[dimensions=None new_sizes=(1,)] jd
          jf:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] je
          jg:f64[2] = concatenate[dimension=0] jc jf
        in (jg,) }
      name=append
    ] bv 0
    jh:f64[] = reduce_max[axes=(0,)] jb
    ji:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jh
    jj:f64[1] = stop_gradient ji
    jk:f64[2] = sub jb jj
    jl:f64[2] = exp jk
    jm:f64[] = reduce_sum[axes=(0,)] jl
    jn:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jm
    jo:f64[2] = div jl jn
    jp:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] jo
    jq:f64[] = squeeze[dimensions=(0,)] jp
    jr:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] jo
    js:f64[] = squeeze[dimensions=(0,)] jr
    jt:f64[2] = xla_call[
      call_jaxpr={ lambda ; ju:f64[1] jv:i64[]. let
          jw:i64[1] = reshape[dimensions=None new_sizes=(1,)] jv
          jx:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] jw
          jy:f64[2] = concatenate[dimension=0] ju jx
        in (jy,) }
      name=append
    ] bs 0
    jz:f64[] = reduce_max[axes=(0,)] jt
    ka:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jz
    kb:f64[1] = stop_gradient ka
    kc:f64[2] = sub jt kb
    kd:f64[2] = exp kc
    ke:f64[] = reduce_sum[axes=(0,)] kd
    kf:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ke
    kg:f64[2] = div kd kf
    kh:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] kg
    ki:f64[] = squeeze[dimensions=(0,)] kh
    kj:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] kg
    kk:f64[] = squeeze[dimensions=(0,)] kj
    kl:f64[2] = xla_call[
      call_jaxpr={ lambda ; km:f64[1] kn:i64[]. let
          ko:i64[1] = reshape[dimensions=None new_sizes=(1,)] kn
          kp:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ko
          kq:f64[2] = concatenate[dimension=0] km kp
        in (kq,) }
      name=append
    ] br 0
    kr:f64[] = reduce_max[axes=(0,)] kl
    ks:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] kr
    kt:f64[1] = stop_gradient ks
    ku:f64[2] = sub kl kt
    kv:f64[2] = exp ku
    kw:f64[] = reduce_sum[axes=(0,)] kv
    kx:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] kw
    ky:f64[2] = div kv kx
    kz:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] ky
    la:f64[] = squeeze[dimensions=(0,)] kz
    lb:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] ky
    lc:f64[] = squeeze[dimensions=(0,)] lb
    ld:f64[2] = mul ho m
    le:f64[2] = add 0.0 ld
    lf:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] n
    lg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dy
    lh:f64[2,2] = mul lg lf
    li:f64[2,2] = add 0.0 lh
    lj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] o
    lk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ea
    ll:f64[2,2] = mul lk lj
    lm:f64[2,2] = add li ll
    ln:f64[2,2] = mul hq lm
    lo:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] le
    lp:f64[2,2] = add lo ln
    lq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] df
    lr:f64[2,2] = mul lq lp
    ls:f64[2,2] = add 0.0 lr
    lt:f64[2] = mul ig p
    lu:f64[2] = add 0.0 lt
    lv:f64[2] = mul ii q
    lw:f64[2] = add lu lv
    lx:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] lw
    ly:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dh
    lz:f64[2,2] = mul ly lx
    ma:f64[2,2] = add ls lz
    mb:f64[2,2] = mul 0.5 ma
    mc:f64[2,2] = add 0.0 mb
    md:f64[2] = mul iy r
    me:f64[2] = add 0.0 md
    mf:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] s
    mg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dy
    mh:f64[2,2] = mul mg mf
    mi:f64[2,2] = add 0.0 mh
    mj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] t
    mk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ea
    ml:f64[2,2] = mul mk mj
    mm:f64[2,2] = add mi ml
    mn:f64[2,2] = mul ja mm
    mo:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] me
    mp:f64[2,2] = add mo mn
    mq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] df
    mr:f64[2,2] = mul mq mp
    ms:f64[2,2] = add 0.0 mr
    mt:f64[2] = mul jq u
    mu:f64[2] = add 0.0 mt
    mv:f64[2] = mul js v
    mw:f64[2] = add mu mv
    mx:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] mw
    my:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dh
    mz:f64[2,2] = mul my mx
    na:f64[2,2] = add ms mz
    nb:f64[2,2] = mul 0.5 na
    nc:f64[2,2] = add mc nb
    nd:f64[2,2] = mul 0.3333333333333333 nc
    ne:f64[2,2] = add 0.0 nd
    nf:f64[2] = mul ki w
    ng:f64[2] = add 0.0 nf
    nh:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] x
    ni:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fk
    nj:f64[2,2] = mul ni nh
    nk:f64[2,2] = add 0.0 nj
    nl:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] y
    nm:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fm
    nn:f64[2,2] = mul nm nl
    no:f64[2,2] = add nk nn
    np:f64[2,2] = mul kk no
    nq:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ng
    nr:f64[2,2] = add nq np
    ns:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] er
    nt:f64[2,2] = mul ns nr
    nu:f64[2,2] = add 0.0 nt
    nv:f64[2] = mul la z
    nw:f64[2] = add 0.0 nv
    nx:f64[2] = mul lc ba
    ny:f64[2] = add nw nx
    nz:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ny
    oa:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] et
    ob:f64[2,2] = mul oa nz
    oc:f64[2,2] = add nu ob
    od:f64[2,2] = mul 0.5 oc
    oe:f64[2,2] = add 0.0 od
    of:f64[2] = mul iy bb
    og:f64[2] = add 0.0 of
    oh:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bc
    oi:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fk
    oj:f64[2,2] = mul oi oh
    ok:f64[2,2] = add 0.0 oj
    ol:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bd
    om:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fm
    on:f64[2,2] = mul om ol
    oo:f64[2,2] = add ok on
    op:f64[2,2] = mul ja oo
    oq:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] og
    or:f64[2,2] = add oq op
    os:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] er
    ot:f64[2,2] = mul os or
    ou:f64[2,2] = add 0.0 ot
    ov:f64[2] = mul jq be
    ow:f64[2] = add 0.0 ov
    ox:f64[2] = mul js bf
    oy:f64[2] = add ow ox
    oz:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] oy
    pa:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] et
    pb:f64[2,2] = mul pa oz
    pc:f64[2,2] = add ou pb
    pd:f64[2,2] = mul 0.5 pc
    pe:f64[2,2] = add oe pd
    pf:f64[2,2] = mul 0.3333333333333333 pe
    pg:f64[2,2] = add ne pf
    ph:f64[2] = mul ki bg
    pi:f64[2] = add 0.0 ph
    pj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bh
    pk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gw
    pl:f64[2,2] = mul pk pj
    pm:f64[2,2] = add 0.0 pl
    pn:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bi
    po:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gy
    pp:f64[2,2] = mul po pn
    pq:f64[2,2] = add pm pp
    pr:f64[2,2] = mul kk pq
    ps:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] pi
    pt:f64[2,2] = add ps pr
    pu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gd
    pv:f64[2,2] = mul pu pt
    pw:f64[2,2] = add 0.0 pv
    px:f64[2] = mul la bj
    py:f64[2] = add 0.0 px
    pz:f64[2] = mul lc bk
    qa:f64[2] = add py pz
    qb:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] qa
    qc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gf
    qd:f64[2,2] = mul qc qb
    qe:f64[2,2] = add pw qd
    qf:f64[2,2] = mul 0.5 qe
    qg:f64[2,2] = add 0.0 qf
    qh:f64[2] = mul ho bl
    qi:f64[2] = add 0.0 qh
    qj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bm
    qk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gw
    ql:f64[2,2] = mul qk qj
    qm:f64[2,2] = add 0.0 ql
    qn:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bn
    qo:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gy
    qp:f64[2,2] = mul qo qn
    qq:f64[2,2] = add qm qp
    qr:f64[2,2] = mul hq qq
    qs:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] qi
    qt:f64[2,2] = add qs qr
    qu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gd
    qv:f64[2,2] = mul qu qt
    qw:f64[2,2] = add 0.0 qv
    qx:f64[2] = mul ig bo
    qy:f64[2] = add 0.0 qx
    qz:f64[2] = mul ii bp
    ra:f64[2] = add qy qz
    rb:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ra
    rc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gf
    rd:f64[2,2] = mul rc rb
    re:f64[2,2] = add qw rd
    rf:f64[2,2] = mul 0.5 re
    rg:f64[2,2] = add qg rf
    rh:f64[2,2] = mul 0.3333333333333333 rg
    ri:f64[2,2] = add pg rh
    rj:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    rk:f64[2,1] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(2, 1)
      unique_indices=True
    ] ri rj
    rl:f64[2] = squeeze[dimensions=(1,)] rk
  in (rl,) }
```
What jax/jaxlib version are you using?
jax 0.3.25, jaxlib 0.3.25
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.8, macOS 11.7
NVIDIA GPU info
No response