jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

.max() yields nan gradient despite individual values and gradients being nan-free #13600

Open carlosgmartin opened 1 year ago

carlosgmartin commented 1 year ago

Description

I have a strange bug:

```
x = [ 0.23814293  0.77400221  0.33243385  1.69167768 -0.17876684 -1.5681684 ]
f(x) = [0.2637331  0.25321472]
f(x).max() = 0.2637331010799241
grad(lambda x: f(x)[0])(x) = [-0.06657368  0.0116883   0.03331883  0.00657176  0.16408419  0.00699502]
grad(lambda x: f(x)[1])(x) = [-0.04887613  0.02342647  0.07204693  0.0035732   0.16981069  0.00377417]
grad(lambda x: f(x).max())(x) = [nan nan nan nan nan nan]
grad(lambda x: f(x)[f(x).argmax()])(x) = [-0.06657368  0.0116883   0.03331883  0.00657176  0.16408419  0.00699502]
grad(lambda x: jnp.nanmax(f(x)))(x) = [-0.06657368  0.0116883   0.03331883  0.00657176  0.16408419  0.00699502]
```

Letting y = f(x): the gradient of y.max() is all NaNs, even though the individual values and their gradients are NaN-free. Replacing y.max() with y[y.argmax()] or jnp.nanmax(y) fixes this.
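For context, the comparison above was produced by code of roughly the following shape. The f below is just an arbitrary smooth stand-in for the real (unshared) f, so this snippet by itself does not reproduce the NaNs:

```
import jax
import jax.numpy as jnp

# Arbitrary stand-in for the real f: any smooth map from R^6 to R^2.
def f(x):
  return jax.nn.softmax(x.reshape(2, 3).sum(axis=1))

x = jnp.array([0.23814293, 0.77400221, 0.33243385,
               1.69167768, -0.17876684, -1.5681684])

print(jax.grad(lambda x: f(x).max())(x))
print(jax.grad(lambda x: f(x)[f(x).argmax()])(x))
print(jax.grad(lambda x: jnp.nanmax(f(x)))(x))
```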

make_jaxpr(f)(x) is the following:

Jaxpr ``` { lambda a:f64[1] b:f64[1] c:f64[1] d:f64[1] e:f64[1] f:f64[1] g:f64[1] h:f64[1] i:f64[1] j:f64[1] k:f64[1] l:f64[1] m:f64[2] n:f64[2] o:f64[2] p:f64[2] q:f64[2] r:f64[2] s:f64[2] t:f64[2] u:f64[2] v:f64[2] w:f64[2] x:f64[2] y:f64[2] z:f64[2] ba:f64[2] bb:f64[2] bc:f64[2] bd:f64[2] be:f64[2] bf:f64[2] bg:f64[2] bh:f64[2] bi:f64[2] bj:f64[2] bk:f64[2] bl:f64[2] bm:f64[2] bn:f64[2] bo:f64[2] bp:f64[2]; bq:f64[6]. let br:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] bq bs:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] bq bt:f64[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] bq bu:f64[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] bq bv:f64[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] bq bw:f64[1] = slice[limit_indices=(6,) start_indices=(5,) strides=None] bq bx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] a by:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] b bz:f64[2,1] = concatenate[dimension=0] bx by ca:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] c cb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] d cc:f64[2,1] = concatenate[dimension=0] ca cb cd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] e ce:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] f cf:f64[2,1] = concatenate[dimension=0] cd ce cg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] g ch:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] h ci:f64[2,1] = concatenate[dimension=0] cg ch cj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] i ck:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] j cl:f64[2,1] = concatenate[dimension=0] cj ck cm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] k cn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] l co:f64[2,1] = concatenate[dimension=0] cm cn cp:f64[2,2] = xla_call[ call_jaxpr={ lambda ; cq:f64[2,1] cr:i64[]. let cs:i64[1] = reshape[dimensions=None new_sizes=(1,)] cr ct:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] cs cu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ct cv:f64[2,2] = concatenate[dimension=1] cq cu in (cv,) } name=append ] bz 0 cw:f64[2] = reduce_max[axes=(1,)] cp cx:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] cw cy:f64[2,1] = stop_gradient cx cz:f64[2,2] = sub cp cy da:f64[2,2] = exp cz db:f64[2] = reduce_sum[axes=(1,)] da dc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] db dd:f64[2,2] = div da dc de:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] dd df:f64[2] = squeeze[dimensions=(1,)] de dg:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] dd dh:f64[2] = squeeze[dimensions=(1,)] dg di:f64[2,2] = xla_call[ call_jaxpr={ lambda ; dj:f64[2,1] dk:i64[]. 
let dl:i64[1] = reshape[dimensions=None new_sizes=(1,)] dk dm:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] dl dn:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] dm do:f64[2,2] = concatenate[dimension=1] dj dn in (do,) } name=append ] cc 0 dp:f64[2] = reduce_max[axes=(1,)] di dq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dp dr:f64[2,1] = stop_gradient dq ds:f64[2,2] = sub di dr dt:f64[2,2] = exp ds du:f64[2] = reduce_sum[axes=(1,)] dt dv:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] du dw:f64[2,2] = div dt dv dx:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] dw dy:f64[2] = squeeze[dimensions=(1,)] dx dz:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] dw ea:f64[2] = squeeze[dimensions=(1,)] dz eb:f64[2,2] = xla_call[ call_jaxpr={ lambda ; ec:f64[2,1] ed:i64[]. let ee:i64[1] = reshape[dimensions=None new_sizes=(1,)] ed ef:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ee eg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ef eh:f64[2,2] = concatenate[dimension=1] ec eg in (eh,) } name=append ] cf 0 ei:f64[2] = reduce_max[axes=(1,)] eb ej:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ei ek:f64[2,1] = stop_gradient ej el:f64[2,2] = sub eb ek em:f64[2,2] = exp el en:f64[2] = reduce_sum[axes=(1,)] em eo:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] en ep:f64[2,2] = div em eo eq:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] ep er:f64[2] = squeeze[dimensions=(1,)] eq es:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] ep et:f64[2] = squeeze[dimensions=(1,)] es eu:f64[2,2] = xla_call[ call_jaxpr={ lambda ; ev:f64[2,1] ew:i64[]. let ex:i64[1] = reshape[dimensions=None new_sizes=(1,)] ew ey:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ex ez:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] ey fa:f64[2,2] = concatenate[dimension=1] ev ez in (fa,) } name=append ] ci 0 fb:f64[2] = reduce_max[axes=(1,)] eu fc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fb fd:f64[2,1] = stop_gradient fc fe:f64[2,2] = sub eu fd ff:f64[2,2] = exp fe fg:f64[2] = reduce_sum[axes=(1,)] ff fh:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fg fi:f64[2,2] = div ff fh fj:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] fi fk:f64[2] = squeeze[dimensions=(1,)] fj fl:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] fi fm:f64[2] = squeeze[dimensions=(1,)] fl fn:f64[2,2] = xla_call[ call_jaxpr={ lambda ; fo:f64[2,1] fp:i64[]. 
let fq:i64[1] = reshape[dimensions=None new_sizes=(1,)] fp fr:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] fq fs:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] fr ft:f64[2,2] = concatenate[dimension=1] fo fs in (ft,) } name=append ] cl 0 fu:f64[2] = reduce_max[axes=(1,)] fn fv:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fu fw:f64[2,1] = stop_gradient fv fx:f64[2,2] = sub fn fw fy:f64[2,2] = exp fx fz:f64[2] = reduce_sum[axes=(1,)] fy ga:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fz gb:f64[2,2] = div fy ga gc:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] gb gd:f64[2] = squeeze[dimensions=(1,)] gc ge:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] gb gf:f64[2] = squeeze[dimensions=(1,)] ge gg:f64[2,2] = xla_call[ call_jaxpr={ lambda ; gh:f64[2,1] gi:i64[]. let gj:i64[1] = reshape[dimensions=None new_sizes=(1,)] gi gk:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] gj gl:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(2, 1)] gk gm:f64[2,2] = concatenate[dimension=1] gh gl in (gm,) } name=append ] co 0 gn:f64[2] = reduce_max[axes=(1,)] gg go:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gn gp:f64[2,1] = stop_gradient go gq:f64[2,2] = sub gg gp gr:f64[2,2] = exp gq gs:f64[2] = reduce_sum[axes=(1,)] gr gt:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gs gu:f64[2,2] = div gr gt gv:f64[2,1] = slice[ limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1) ] gu gw:f64[2] = squeeze[dimensions=(1,)] gv gx:f64[2,1] = slice[ limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1) ] gu gy:f64[2] = squeeze[dimensions=(1,)] gx gz:f64[2] = xla_call[ call_jaxpr={ lambda ; ha:f64[1] hb:i64[]. let hc:i64[1] = reshape[dimensions=None new_sizes=(1,)] hb hd:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] hc he:f64[2] = concatenate[dimension=0] ha hd in (he,) } name=append ] bu 0 hf:f64[] = reduce_max[axes=(0,)] gz hg:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hf hh:f64[1] = stop_gradient hg hi:f64[2] = sub gz hh hj:f64[2] = exp hi hk:f64[] = reduce_sum[axes=(0,)] hj hl:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hk hm:f64[2] = div hj hl hn:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] hm ho:f64[] = squeeze[dimensions=(0,)] hn hp:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] hm hq:f64[] = squeeze[dimensions=(0,)] hp hr:f64[2] = xla_call[ call_jaxpr={ lambda ; hs:f64[1] ht:i64[]. let hu:i64[1] = reshape[dimensions=None new_sizes=(1,)] ht hv:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] hu hw:f64[2] = concatenate[dimension=0] hs hv in (hw,) } name=append ] bt 0 hx:f64[] = reduce_max[axes=(0,)] hr hy:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] hx hz:f64[1] = stop_gradient hy ia:f64[2] = sub hr hz ib:f64[2] = exp ia ic:f64[] = reduce_sum[axes=(0,)] ib id:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ic ie:f64[2] = div ib id if:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] ie ig:f64[] = squeeze[dimensions=(0,)] if ih:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] ie ii:f64[] = squeeze[dimensions=(0,)] ih ij:f64[2] = xla_call[ call_jaxpr={ lambda ; ik:f64[1] il:i64[]. 
let im:i64[1] = reshape[dimensions=None new_sizes=(1,)] il in:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] im io:f64[2] = concatenate[dimension=0] ik in in (io,) } name=append ] bw 0 ip:f64[] = reduce_max[axes=(0,)] ij iq:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ip ir:f64[1] = stop_gradient iq is:f64[2] = sub ij ir it:f64[2] = exp is iu:f64[] = reduce_sum[axes=(0,)] it iv:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] iu iw:f64[2] = div it iv ix:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] iw iy:f64[] = squeeze[dimensions=(0,)] ix iz:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] iw ja:f64[] = squeeze[dimensions=(0,)] iz jb:f64[2] = xla_call[ call_jaxpr={ lambda ; jc:f64[1] jd:i64[]. let je:i64[1] = reshape[dimensions=None new_sizes=(1,)] jd jf:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] je jg:f64[2] = concatenate[dimension=0] jc jf in (jg,) } name=append ] bv 0 jh:f64[] = reduce_max[axes=(0,)] jb ji:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jh jj:f64[1] = stop_gradient ji jk:f64[2] = sub jb jj jl:f64[2] = exp jk jm:f64[] = reduce_sum[axes=(0,)] jl jn:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jm jo:f64[2] = div jl jn jp:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] jo jq:f64[] = squeeze[dimensions=(0,)] jp jr:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] jo js:f64[] = squeeze[dimensions=(0,)] jr jt:f64[2] = xla_call[ call_jaxpr={ lambda ; ju:f64[1] jv:i64[]. let jw:i64[1] = reshape[dimensions=None new_sizes=(1,)] jv jx:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] jw jy:f64[2] = concatenate[dimension=0] ju jx in (jy,) } name=append ] bs 0 jz:f64[] = reduce_max[axes=(0,)] jt ka:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] jz kb:f64[1] = stop_gradient ka kc:f64[2] = sub jt kb kd:f64[2] = exp kc ke:f64[] = reduce_sum[axes=(0,)] kd kf:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ke kg:f64[2] = div kd kf kh:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] kg ki:f64[] = squeeze[dimensions=(0,)] kh kj:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] kg kk:f64[] = squeeze[dimensions=(0,)] kj kl:f64[2] = xla_call[ call_jaxpr={ lambda ; km:f64[1] kn:i64[]. 
let ko:i64[1] = reshape[dimensions=None new_sizes=(1,)] kn kp:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] ko kq:f64[2] = concatenate[dimension=0] km kp in (kq,) } name=append ] br 0 kr:f64[] = reduce_max[axes=(0,)] kl ks:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] kr kt:f64[1] = stop_gradient ks ku:f64[2] = sub kl kt kv:f64[2] = exp ku kw:f64[] = reduce_sum[axes=(0,)] kv kx:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] kw ky:f64[2] = div kv kx kz:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] ky la:f64[] = squeeze[dimensions=(0,)] kz lb:f64[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] ky lc:f64[] = squeeze[dimensions=(0,)] lb ld:f64[2] = mul ho m le:f64[2] = add 0.0 ld lf:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] n lg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dy lh:f64[2,2] = mul lg lf li:f64[2,2] = add 0.0 lh lj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] o lk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ea ll:f64[2,2] = mul lk lj lm:f64[2,2] = add li ll ln:f64[2,2] = mul hq lm lo:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] le lp:f64[2,2] = add lo ln lq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] df lr:f64[2,2] = mul lq lp ls:f64[2,2] = add 0.0 lr lt:f64[2] = mul ig p lu:f64[2] = add 0.0 lt lv:f64[2] = mul ii q lw:f64[2] = add lu lv lx:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] lw ly:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dh lz:f64[2,2] = mul ly lx ma:f64[2,2] = add ls lz mb:f64[2,2] = mul 0.5 ma mc:f64[2,2] = add 0.0 mb md:f64[2] = mul iy r me:f64[2] = add 0.0 md mf:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] s mg:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dy mh:f64[2,2] = mul mg mf mi:f64[2,2] = add 0.0 mh mj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] t mk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] ea ml:f64[2,2] = mul mk mj mm:f64[2,2] = add mi ml mn:f64[2,2] = mul ja mm mo:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] me mp:f64[2,2] = add mo mn mq:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] df mr:f64[2,2] = mul mq mp ms:f64[2,2] = add 0.0 mr mt:f64[2] = mul jq u mu:f64[2] = add 0.0 mt mv:f64[2] = mul js v mw:f64[2] = add mu mv mx:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] mw my:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] dh mz:f64[2,2] = mul my mx na:f64[2,2] = add ms mz nb:f64[2,2] = mul 0.5 na nc:f64[2,2] = add mc nb nd:f64[2,2] = mul 0.3333333333333333 nc ne:f64[2,2] = add 0.0 nd nf:f64[2] = mul ki w ng:f64[2] = add 0.0 nf nh:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] x ni:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fk nj:f64[2,2] = mul ni nh nk:f64[2,2] = add 0.0 nj nl:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] y nm:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fm nn:f64[2,2] = mul nm nl no:f64[2,2] = add nk nn np:f64[2,2] = mul kk no nq:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ng nr:f64[2,2] = add nq np ns:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] er nt:f64[2,2] = mul ns nr nu:f64[2,2] = add 0.0 nt nv:f64[2] = mul la z nw:f64[2] = add 0.0 nv nx:f64[2] = 
mul lc ba ny:f64[2] = add nw nx nz:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ny oa:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] et ob:f64[2,2] = mul oa nz oc:f64[2,2] = add nu ob od:f64[2,2] = mul 0.5 oc oe:f64[2,2] = add 0.0 od of:f64[2] = mul iy bb og:f64[2] = add 0.0 of oh:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bc oi:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fk oj:f64[2,2] = mul oi oh ok:f64[2,2] = add 0.0 oj ol:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bd om:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] fm on:f64[2,2] = mul om ol oo:f64[2,2] = add ok on op:f64[2,2] = mul ja oo oq:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] og or:f64[2,2] = add oq op os:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] er ot:f64[2,2] = mul os or ou:f64[2,2] = add 0.0 ot ov:f64[2] = mul jq be ow:f64[2] = add 0.0 ov ox:f64[2] = mul js bf oy:f64[2] = add ow ox oz:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] oy pa:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] et pb:f64[2,2] = mul pa oz pc:f64[2,2] = add ou pb pd:f64[2,2] = mul 0.5 pc pe:f64[2,2] = add oe pd pf:f64[2,2] = mul 0.3333333333333333 pe pg:f64[2,2] = add ne pf ph:f64[2] = mul ki bg pi:f64[2] = add 0.0 ph pj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bh pk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gw pl:f64[2,2] = mul pk pj pm:f64[2,2] = add 0.0 pl pn:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bi po:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gy pp:f64[2,2] = mul po pn pq:f64[2,2] = add pm pp pr:f64[2,2] = mul kk pq ps:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] pi pt:f64[2,2] = add ps pr pu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gd pv:f64[2,2] = mul pu pt pw:f64[2,2] = add 0.0 pv px:f64[2] = mul la bj py:f64[2] = add 0.0 px pz:f64[2] = mul lc bk qa:f64[2] = add py pz qb:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] qa qc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gf qd:f64[2,2] = mul qc qb qe:f64[2,2] = add pw qd qf:f64[2,2] = mul 0.5 qe qg:f64[2,2] = add 0.0 qf qh:f64[2] = mul ho bl qi:f64[2] = add 0.0 qh qj:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bm qk:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gw ql:f64[2,2] = mul qk qj qm:f64[2,2] = add 0.0 ql qn:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] bn qo:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gy qp:f64[2,2] = mul qo qn qq:f64[2,2] = add qm qp qr:f64[2,2] = mul hq qq qs:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] qi qt:f64[2,2] = add qs qr qu:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gd qv:f64[2,2] = mul qu qt qw:f64[2,2] = add 0.0 qv qx:f64[2] = mul ig bo qy:f64[2] = add 0.0 qx qz:f64[2] = mul ii bp ra:f64[2] = add qy qz rb:f64[1,2] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 2)] ra rc:f64[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] gf rd:f64[2,2] = mul rc rb re:f64[2,2] = add qw rd rf:f64[2,2] = mul 0.5 re rg:f64[2,2] = add qg rf rh:f64[2,2] = mul 0.3333333333333333 rg ri:f64[2,2] = add pg rh rj:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 rk:f64[2,1] = gather[ 
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(2, 1) unique_indices=True ] ri rj rl:f64[2] = squeeze[dimensions=(1,)] rk in (rl,) } ```

What jax/jaxlib version are you using?

jax 0.3.25, jaxlib 0.3.25

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.8, macOS 11.7

NVIDIA GPU info

No response

zhangqiaorjc commented 1 year ago

I wonder if the checkify NaN checker can narrow it down: https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html?highlight=checkify#the-checkify-transformation
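The pattern looks something like the following sketch, with f and x standing in for your function and input:

```
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):  # stand-in for the actual function
  return jax.nn.softmax(x)

x = jnp.zeros(3)

# Wrap the gradient computation so NaN/Inf production anywhere inside
# is recorded and can be raised as an error afterwards.
checked_grad = checkify.checkify(
    jax.grad(lambda x: f(x).max()), errors=checkify.all_checks)
err, g = checked_grad(x)
err.throw()  # raises if a NaN/Inf was produced anywhere inside
print(g)
```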

Also, do you have a jnp.where in your f? See https://jax.readthedocs.io/en/latest/faq.html?highlight=nan#gradients-contain-nan-where-using-where
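For reference, the failure mode in that FAQ entry looks like this: the untaken branch of a where still participates in the backward pass, so an undefined derivative there poisons the gradient:

```
import jax
import jax.numpy as jnp

# The untaken log branch still receives a (zero) cotangent, and its
# VJP computes 0 / 0 = nan at x = 0:
print(jax.grad(lambda x: jnp.where(x > 0, jnp.log(x), 0.0))(0.0))  # nan
# The standard fix: a second where keeps the inner derivative finite:
print(jax.grad(
    lambda x: jnp.where(x > 0, jnp.log(jnp.where(x > 0, x, 1.0)), 0.0))(0.0))  # 0.0
```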

carlosgmartin commented 1 year ago

@zhangqiaorjc I tried using checkify with errors=checkify.all_checks (is checkify.all_errors a typo on the linked page?) and err.throw(). No error was raised. To my knowledge, f does not use jnp.where.

Perhaps a clue lies in the difference between how the gradient of _.max() is computed and how the gradients of _[_.argmax()] and jnp.nanmax(_) are computed?

mattjj commented 1 year ago

You can also use jax_debug_nans, though it may raise the error inside JAX's backward-pass code, which can look unfamiliar.
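Besides the JAX_DEBUG_NANS environment variable, it can be flipped on in code:

```
import jax

jax.config.update("jax_debug_nans", True)  # equivalent to JAX_DEBUG_NANS=True
```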

Any chance you could share a runnable repro? (The smaller the better!)

mattjj commented 1 year ago

JAX treats differentiation of lambda x: x[x.argmax()] differently from how it treats differentiation of lambda x: x.max(). The difference is in how ties are handled, i.e. non-unique maxima:

```
import jax
import jax.numpy as jnp

def f1(x):
  return x.max()

def f2(x):
  return x[x.argmax()]

x = jnp.ones(2)

print(jax.grad(f1)(x))  # [0.5 0.5]
print(jax.grad(f2)(x))  # [1. 0.]
```

The gradient (in the sense of a representer vector for the Fréchet derivative) is not mathematically well defined for (x, y) ↦ max(x, y) where x == y. So we have to either raise an error or choose some convention. One viable convention is to pick one of the directional derivatives, e.g. the one corresponding to the first maximum (first according to, say, the left-to-right order in a flattened version of the input). Another convention, rather than arbitrarily choosing one of the maxima, is to be symmetric under permutation of the maxima.

In JAX we chose the latter. (Actually we chose it when writing the original Autograd, then stuck with it for JAX.) But that means that lambda x: x[x.argmax()] differentiates differently than lambda x: x.max(), because argmax is defined (in its docstring!) to choose the first maximum when there are multiple. (This pattern is common in autodiff: different programming-language denotations of the same mathematical function can lead to different automatic derivatives.)
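To make the two conventions concrete, here is an illustrative sketch (not JAX's actual implementation) of the gradient each convention assigns to max at a tie:

```
import jax.numpy as jnp

def grad_max_first(x):
  # "Pick the first maximum": the whole cotangent goes to the first
  # argmax, matching how x[x.argmax()] differentiates.
  return jnp.zeros_like(x).at[jnp.argmax(x)].set(1.0)

def grad_max_symmetric(x):
  # "Symmetric under permutation": split the cotangent evenly among
  # all tied maxima (the convention Autograd and JAX use for x.max()).
  mask = (x == x.max()).astype(x.dtype)
  return mask / mask.sum()

x = jnp.ones(2)
print(grad_max_first(x))      # [1. 0.]
print(grad_max_symmetric(x))  # [0.5 0.5]
```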

A consequence is that there may be functions f and values of x such that grad(lambda x: f1(f(x)))(x) is nan but grad(lambda x: f2(f(x)))(x) is not. But I'm interested to see if that's what's being constructed! Or maybe there's something else going on.

In other words, I'm just explaining why x.max() and x[x.argmax()] could behave differently under autodiff. But I don't know exactly what's going on without a repro.

carlosgmartin commented 1 year ago

@mattjj This function is buried deep inside a jitted part of the program and entangled with a lot of other code, so I'm having trouble extracting a minimal example. Setting the JAX_DEBUG_NANS=True environment variable yields:

```
Traceback (most recent call last):
  [...]
    return lax.scan(lambda x, _: (f(x), x), x, None, n + 1)[1]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 275, in scan
    out = scan_p.bind(*consts, *in_flat,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1000, in scan_bind
    return core.AxisPrimitive.bind(scan_p, *args, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/core.py", line 2444, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
    return compiled_fun(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 896, in _execute_compiled
    check_special(name, out_flat)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 844, in check_special
    _check_special(name, buf.dtype, buf)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 849, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
jax._src.traceback_util.UnfilteredStackTrace: FloatingPointError: invalid value (nan) encountered in scan
```

For some reason, the trace seems to go no "deeper" than the top-level scan that creates the jitted part of the program.

It's interesting that jnp.nanmax works where jnp.max doesn't, despite the absence of NaN values. nanmax also uses the permutation-symmetric convention you mentioned. Where can one find the gradient-computation code for these two functions, to compare them side by side? Perhaps carefully examining max's gradient-computation code would reveal some kind of edge/failure case?
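In the meantime, one way to compare them side by side without digging through the source is to print the jaxprs of the two gradient computations:

```
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
print(jax.make_jaxpr(jax.grad(jnp.max))(x))     # gradient graph of max
print(jax.make_jaxpr(jax.grad(jnp.nanmax))(x))  # gradient graph of nanmax
```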

Regarding a reproducible example, is it possible to get some representation of f with concrete values for all constants, while deep inside a jit? I tried setting up some code to manually detect when a nan gradient appears and run jax.debug.callback(callback_fn) if it does. However, playing around with f inside that callback leads to complaints about leaks:

```
*** jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function
transformed by JAX had a side effect, allowing for a reference to an intermediate value
with type float64[6] wrapped in a DynamicJaxprTracer to escape the scope of the
transformation. JAX transformations require that functions explicitly return their
outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was at [...] traced for scan.
------------------------------
The leaked intermediate value was created on line [...].
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding
JAX-internal frames were:
------------------------------
[...]
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS
or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```

I also tried explicitly passing the relevant functions and values to jax.debug.callback through the *args, but it then complains that the functions are not valid JAX types.

carlosgmartin commented 1 year ago

I solved the above callback issues by using something of the form:

```
jax.debug.callback(functools.partial(callback_fn, *all_static_args), *all_dynamic_args)
```
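Fleshed out, the working pattern is roughly the following (callback_fn and f here are stand-ins for the real ones):

```
import functools
import jax
import jax.numpy as jnp

def callback_fn(f, x):
  # f is baked in statically via functools.partial; x arrives as a
  # concrete array, so it can be inspected freely at runtime.
  print(jax.grad(lambda x: f(x).max())(x))

def traced_code(x):
  # Only array arguments are passed through the callback's *args.
  jax.debug.callback(functools.partial(callback_fn, jnp.tanh), x)
  return x

jax.jit(traced_code)(jnp.arange(3.0))
```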

However, something strange happens: when I compute and print the same gradient inside the callback, the NaNs disappear and the values are correct (still using jnp.max).