alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

Error when running pipeshard parallel: `assert len(gradients) == len(microbatch_bound.invars)` #829

Closed: frankxyy closed this issue 1 year ago

frankxyy commented 1 year ago

I am running pipeshard parallelism for a ViT-Large model. The relevant code is:

method = alpa.PipeshardParallel(num_micro_batches=training_args.num_micro_batches,
                                layer_option=alpa.AutoLayerOption(layer_num=2),
                                stage_option="auto")
p_train_step = alpa.parallelize(train_step, method=method)

Note that I set num_micro_batches to 1.
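
For context, this is roughly how the compiled step is driven (a sketch: alpa.init and train_dataloader are assumptions, not taken from the original script):

import alpa

# Assumed cluster setup; the original script's initialization is not shown.
alpa.init(cluster="ray")

method = alpa.PipeshardParallel(num_micro_batches=1,
                                layer_option=alpa.AutoLayerOption(layer_num=2),
                                stage_option="auto")
p_train_step = alpa.parallelize(train_step, method=method)

for batch in train_dataloader:  # hypothetical iterator of dict-shaped batches
    state, metrics = p_train_step(state, batch)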

After running, the following error is thrown:

Traceback (most recent call last):
  File "run_image_classification.py", line 590, in <module>
    main()
  File "run_image_classification.py", line 555, in main
    state, metrics = p_train_step(state, batch)
  File "/data1/home/xuyangyang/alpa/alpa/pipeline_parallel/compile_executable.py", line 93, in compile_pipeshard_executable
    pipeshard_config = compile_pipeshard_executable_internal(
  File "/data1/home/xuyangyang/alpa/alpa/pipeline_parallel/compile_executable.py", line 170, in compile_pipeshard_executable_internal
    apply_grad_global_info) = _slice_apply_grad_for_stage_construction(
  File "/data1/home/xuyangyang/alpa/alpa/pipeline_parallel/compile_executable.py", line 393, in _slice_apply_grad_for_stage_construction
    _) = process_apply_gradient(apply_grad_jaxpr, microbatch_bound,
  File "/data1/home/xuyangyang/alpa/alpa/pipeline_parallel/apply_grad.py", line 393, in process_apply_gradient
    assert len(gradients) == len(microbatch_bound.invars)
AssertionError

I debugged down to the failing line and found that the length of gradients is 392 while the length of microbatch_bound.invars is 393. The stored values are:

(Pdb) len(gradients)
392
(Pdb) len(microbatch_bound.invars)
393
(Pdb) gradients
[luf, lug, luh, lui, luj, luk, lul, lum, lun, luo, lup, luq, lur, lus, lut, luu, luv, luw, lux, luy, luz, lva, lvb, lvc, lvd, lve, lvf, lvg, lvh, lvi, lvj, lvk, lvl, lvm, lvn, lvo, lvp, lvq, lvr, lvs, lvt, lvu, lvv, lvw, lvx, lvy, lvz, lwa, lwb, lwc, lwd, lwe, lwf, lwg, lwh, lwi, lwj, lwk, lwl, lwm, lwn, lwo, lwp, lwq, lwr, lws, lwt, lwu, lwv, lww, lwx, lwy, lwz, lxa, lxb, lxc, lxd, lxe, lxf, lxg, lxh, lxi, lxj, lxk, lxl, lxm, lxn, lxo, lxp, lxq, lxr, lxs, lxt, lxu, lxv, lxw, lxx, lxy, lxz, lya, lyb, lyc, lyd, lye, lyf, lyg, lyh, lyi, lyj, lyk, lyl, lym, lyn, lyo, lyp, lyq, lyr, lys, lyt, lyu, lyv, lyw, lyx, lyy, lyz, lza, lzb, lzc, lzd, lze, lzf, lzg, lzh, lzi, lzj, lzk, lzl, lzm, lzn, lzo, lzp, lzq, lzr, lzs, lzt, lzu, lzv, lzw, lzx, lzy, lzz, maa, mab, mac, mad, mae, maf, mag, mah, mai, maj, mak, mal, mam, man, mao, map, maq, mar, mas, mat, mau, mav, maw, max, may, maz, mba, mbb, mbc, mbd, mbe, mbf, mbg, mbh, mbi, mbj, mbk, mbl, mbm, mbn, mbo, mbp, mbq, mbr, mbs, mbt, mbu, mbv, mbw, mbx, mby, mbz, mca, mcb, mcc, mcd, mce, mcf, mcg, mch, mci, mcj, mck, mcl, mcm, mcn, mco, mcp, mcq, mcr, mcs, mct, mcu, mcv, mcw, mcx, mcy, mcz, mda, mdb, mdc, mdd, mde, mdf, mdg, mdh, mdi, mdj, mdk, mdl, mdm, mdn, mdo, mdp, mdq, mdr, mds, mdt, mdu, mdv, mdw, mdx, mdy, mdz, mea, meb, mec, med, mee, mef, meg, meh, mei, mej, mek, mel, mem, men, meo, mep, meq, mer, mes, met, meu, mev, mew, mex, mey, mez, mfa, mfb, mfc, mfd, mfe, mff, mfg, mfh, mfi, mfj, mfk, mfl, mfm, mfn, mfo, mfp, mfq, mfr, mfs, mft, mfu, mfv, mfw, mfx, mfy, mfz, mga, mgb, mgc, mgd, mge, mgf, mgg, mgh, mgi, mgj, mgk, mgl, mgm, mgn, mgo, mgp, mgq, mgr, mgs, mgt, mgu, mgv, mgw, mgx, mgy, mgz, mha, mhb, mhc, mhd, mhe, mhf, mhg, mhh, mhi, mhj, mhk, mhl, mhm, mhn, mho, mhp, mhq, mhr, mhs, mht, mhu, mhv, mhw, mhx, mhy, mhz, mia, mib, mic, mid, mie, mif, mig, mih, mii, mij, mik, mil, mim, min, mio, mip, miq, mir, mis, mit, miu, miv, miw, mix, miy, miz, mja, mjb, mjc, mjd, mje, mjf, mjg]
(Pdb) microbatch_bound
_:f32[] a:f32[10] b:f32[2048,10] c:f32[1,1,2048] d:f32[2048] e:f32[16,16,3,2048]
  f:f32[1,197,2048] g:f32[2048] h:f32[2048,2048] i:f32[2048] j:f32[2048,2048] k:f32[2048]
  l:f32[2048,2048] m:f32[2048] n:f32[2048,2048] o:f32[8192] p:f32[2048,8192] q:f32[2048]
  r:f32[2048] s:f32[2048] t:f32[2048] u:f32[2048] v:f32[8192,2048] w:f32[2048] x:f32[2048,2048]
  y:f32[2048] z:f32[2048,2048] ba:f32[2048] bb:f32[2048,2048] bc:f32[2048] bd:f32[2048,2048]
  be:f32[8192] bf:f32[2048,8192] bg:f32[2048] bh:f32[2048] bi:f32[2048] bj:f32[2048]
  bk:f32[2048] bl:f32[8192,2048] bm:f32[2048] bn:f32[2048,2048] bo:f32[2048] bp:f32[2048,2048]
  bq:f32[2048] br:f32[2048,2048] bs:f32[2048] bt:f32[2048,2048] bu:f32[8192] bv:f32[2048,8192]
  bw:f32[2048] bx:f32[2048] by:f32[2048] bz:f32[2048] ca:f32[2048] cb:f32[8192,2048]
  cc:f32[2048] cd:f32[2048,2048] ce:f32[2048] cf:f32[2048,2048] cg:f32[2048] ch:f32[2048,2048]
  ci:f32[2048] cj:f32[2048,2048] ck:f32[8192] cl:f32[2048,8192] cm:f32[2048] cn:f32[2048]
  co:f32[2048] cp:f32[2048] cq:f32[2048] cr:f32[8192,2048] cs:f32[2048] ct:f32[2048,2048]
  cu:f32[2048] cv:f32[2048,2048] cw:f32[2048] cx:f32[2048,2048] cy:f32[2048] cz:f32[2048,2048]
  da:f32[8192] db:f32[2048,8192] dc:f32[2048] dd:f32[2048] de:f32[2048] df:f32[2048]
  dg:f32[2048] dh:f32[8192,2048] di:f32[2048] dj:f32[2048,2048] dk:f32[2048] dl:f32[2048,2048]
  dm:f32[2048] dn:f32[2048,2048] do:f32[2048] dp:f32[2048,2048] dq:f32[8192] dr:f32[2048,8192]
  ds:f32[2048] dt:f32[2048] du:f32[2048] dv:f32[2048] dw:f32[2048] dx:f32[8192,2048]
  dy:f32[2048] dz:f32[2048,2048] ea:f32[2048] eb:f32[2048,2048] ec:f32[2048] ed:f32[2048,2048]
  ee:f32[2048] ef:f32[2048,2048] eg:f32[8192] eh:f32[2048,8192] ei:f32[2048] ej:f32[2048]
  ek:f32[2048] el:f32[2048] em:f32[2048] en:f32[8192,2048] eo:f32[2048] ep:f32[2048,2048]
  eq:f32[2048] er:f32[2048,2048] es:f32[2048] et:f32[2048,2048] eu:f32[2048] ev:f32[2048,2048]
  ew:f32[8192] ex:f32[2048,8192] ey:f32[2048] ez:f32[2048] fa:f32[2048] fb:f32[2048]
  fc:f32[2048] fd:f32[8192,2048] fe:f32[2048] ff:f32[2048,2048] fg:f32[2048] fh:f32[2048,2048]
  fi:f32[2048] fj:f32[2048,2048] fk:f32[2048] fl:f32[2048,2048] fm:f32[8192] fn:f32[2048,8192]
  fo:f32[2048] fp:f32[2048] fq:f32[2048] fr:f32[2048] fs:f32[2048] ft:f32[8192,2048]
  fu:f32[2048] fv:f32[2048,2048] fw:f32[2048] fx:f32[2048,2048] fy:f32[2048] fz:f32[2048,2048]
  ga:f32[2048] gb:f32[2048,2048] gc:f32[8192] gd:f32[2048,8192] ge:f32[2048] gf:f32[2048]
  gg:f32[2048] gh:f32[2048] gi:f32[2048] gj:f32[8192,2048] gk:f32[2048] gl:f32[2048,2048]
  gm:f32[2048] gn:f32[2048,2048] go:f32[2048] gp:f32[2048,2048] gq:f32[2048] gr:f32[2048,2048]
  gs:f32[8192] gt:f32[2048,8192] gu:f32[2048] gv:f32[2048] gw:f32[2048] gx:f32[2048]
  gy:f32[2048] gz:f32[8192,2048] ha:f32[2048] hb:f32[2048,2048] hc:f32[2048] hd:f32[2048,2048]
  he:f32[2048] hf:f32[2048,2048] hg:f32[2048] hh:f32[2048,2048] hi:f32[8192] hj:f32[2048,8192]
  hk:f32[2048] hl:f32[2048] hm:f32[2048] hn:f32[2048] ho:f32[2048] hp:f32[8192,2048]
  hq:f32[2048] hr:f32[2048,2048] hs:f32[2048] ht:f32[2048,2048] hu:f32[2048] hv:f32[2048,2048]
  hw:f32[2048] hx:f32[2048,2048] hy:f32[8192] hz:f32[2048,8192] ia:f32[2048] ib:f32[2048]
  ic:f32[2048] id:f32[2048] ie:f32[2048] if:f32[8192,2048] ig:f32[2048] ih:f32[2048,2048]
  ii:f32[2048] ij:f32[2048,2048] ik:f32[2048] il:f32[2048,2048] im:f32[2048] in:f32[2048,2048]
  io:f32[8192] ip:f32[2048,8192] iq:f32[2048] ir:f32[2048] is:f32[2048] it:f32[2048]
  iu:f32[2048] iv:f32[8192,2048] iw:f32[2048] ix:f32[2048,2048] iy:f32[2048] iz:f32[2048,2048]
  ja:f32[2048] jb:f32[2048,2048] jc:f32[2048] jd:f32[2048,2048] je:f32[8192] jf:f32[2048,8192]
  jg:f32[2048] jh:f32[2048] ji:f32[2048] jj:f32[2048] jk:f32[2048] jl:f32[8192,2048]
  jm:f32[2048] jn:f32[2048,2048] jo:f32[2048] jp:f32[2048,2048] jq:f32[2048] jr:f32[2048,2048]
  js:f32[2048] jt:f32[2048,2048] ju:f32[8192] jv:f32[2048,8192] jw:f32[2048] jx:f32[2048]
  jy:f32[2048] jz:f32[2048] ka:f32[2048] kb:f32[8192,2048] kc:f32[2048] kd:f32[2048,2048]
  ke:f32[2048] kf:f32[2048,2048] kg:f32[2048] kh:f32[2048,2048] ki:f32[2048] kj:f32[2048,2048]
  kk:f32[8192] kl:f32[2048,8192] km:f32[2048] kn:f32[2048] ko:f32[2048] kp:f32[2048]
  kq:f32[2048] kr:f32[8192,2048] ks:f32[2048] kt:f32[2048,2048] ku:f32[2048] kv:f32[2048,2048]
  kw:f32[2048] kx:f32[2048,2048] ky:f32[2048] kz:f32[2048,2048] la:f32[8192] lb:f32[2048,8192]
  lc:f32[2048] ld:f32[2048] le:f32[2048] lf:f32[2048] lg:f32[2048] lh:f32[8192,2048]
  li:f32[2048] lj:f32[2048,2048] lk:f32[2048] ll:f32[2048,2048] lm:f32[2048] ln:f32[2048,2048]
  lo:f32[2048] lp:f32[2048,2048] lq:f32[8192] lr:f32[2048,8192] ls:f32[2048] lt:f32[2048]
  lu:f32[2048] lv:f32[2048] lw:f32[2048] lx:f32[8192,2048] ly:f32[2048] lz:f32[2048,2048]
  ma:f32[2048] mb:f32[2048,2048] mc:f32[2048] md:f32[2048,2048] me:f32[2048] mf:f32[2048,2048]
  mg:f32[8192] mh:f32[2048,8192] mi:f32[2048] mj:f32[2048] mk:f32[2048] ml:f32[2048]
  mm:f32[2048] mn:f32[8192,2048] mo:f32[2048] mp:f32[2048,2048] mq:f32[2048] mr:f32[2048,2048]
  ms:f32[2048] mt:f32[2048,2048] mu:f32[2048] mv:f32[2048,2048] mw:f32[8192] mx:f32[2048,8192]
  my:f32[2048] mz:f32[2048] na:f32[2048] nb:f32[2048] nc:f32[2048] nd:f32[8192,2048]
  ne:f32[2048] nf:f32[2048,2048] ng:f32[2048] nh:f32[2048,2048] ni:f32[2048] nj:f32[2048,2048]
  nk:f32[2048] nl:f32[2048,2048] nm:f32[8192] nn:f32[2048,8192] no:f32[2048] np:f32[2048]
  nq:f32[2048] nr:f32[2048] ns:f32[2048] nt:f32[8192,2048] nu:f32[2048] nv:f32[2048,2048]
  nw:f32[2048] nx:f32[2048,2048] ny:f32[2048] nz:f32[2048,2048] oa:f32[2048] ob:f32[2048,2048]
  oc:f32[8192] od:f32[2048,8192] oe:f32[2048] of:f32[2048] og:f32[2048] oh:f32[2048]
  oi:f32[2048] oj:f32[8192,2048] ok:f32[2048] ol:f32[2048,2048] om:f32[2048] on:f32[2048,2048]
  oo:f32[2048] op:f32[2048,2048] oq:f32[2048] or:f32[2048,2048] os:f32[8192] ot:f32[2048,8192]
  ou:f32[2048] ov:f32[2048] ow:f32[2048] ox:f32[2048] oy:f32[2048] oz:f32[8192,2048]
  pa:f32[2048] pb:f32[2048] = pipeline_marker[mark_type=grad name=grad] pc pd pe
  pf pg ph pi pj pk pl pm pn po pp pq pr ps pt pu pv pw px py pz qa qb qc qd qe qf
  qg qh qi qj qk ql qm qn qo qp qq qr qs qt qu qv qw qx qy qz ra rb rc rd re rf rg
  rh ri rj rk rl rm rn ro rp rq rr rs rt ru rv rw rx ry rz sa sb sc sd se sf sg sh
  si sj sk sl sm sn so sp sq sr ss st su sv sw sx sy sz ta tb tc td te tf tg th ti
  tj tk tl tm tn to tp tq tr ts tt tu tv tw tx ty tz ua ub uc ud ue uf ug uh ui uj
  uk ul um un uo up uq ur us ut uu uv uw ux uy uz va vb vc vd ve vf vg vh vi vj vk
  vl vm vn vo vp vq vr vs vt vu vv vw vx vy vz wa wb wc wd we wf wg wh wi wj wk wl
  wm wn wo wp wq wr ws wt wu wv ww wx wy wz xa xb xc xd xe xf xg xh xi xj xk xl xm
  xn xo xp xq xr xs xt xu xv xw xx xy xz ya yb yc yd ye yf yg yh yi yj yk yl ym yn
  yo yp yq yr ys yt yu yv yw yx yy yz za zb zc zd ze zf zg zh zi zj zk zl zm zn zo
  zp zq zr zs zt zu zv zw zx zy zz baa bab bac bad bae baf bag bah bai baj bak bal
  bam ban bao bap baq bar bas bat bau bav baw bax bay baz bba bbb bbc bbd bbe bbf
  bbg bbh bbi bbj bbk bbl bbm bbn bbo bbp bbq bbr bbs bbt bbu bbv bbw bbx bby bbz
  bca bcb bcc bcd bce bcf bcg bch bci bcj bck bcl bcm bcn bco bcp bcq bcr bcs bct
  bcu bcv bcw bcx bcy bcz bda bdb bdc bdd bde bdf bdg bdh bdi bdj bdk bdl bdm bdn
  bdo bdp bdq bdr bds bdt bdu bdv bdw bdx bdy bdz bea beb bec bed bee
(Pdb)
frankxyy commented 1 year ago

[image] It seems that there is a DropVar in microbatch_bound.outvars.
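
If so, that would explain the off-by-one: the boundary equation carries one invar/outvar pair whose output is dropped, so the raw invar count exceeds the number of real gradients by one. A minimal sketch of the idea (assumed names; this is not Alpa's actual fix): skip pairs whose outvar is a DropVar before comparing counts.

from jax.core import DropVar  # assumes a JAX version where DropVar is a class

def pair_gradients(gradients, microbatch_bound):
    # Hypothetical illustration: pair the boundary equation's invars with the
    # gradient list, ignoring invars whose corresponding outvar was dropped.
    # In the failing run this would reconcile 392 gradients with 393 invars.
    kept_invars = [
        invar
        for invar, outvar in zip(microbatch_bound.invars, microbatch_bound.outvars)
        if not isinstance(outvar, DropVar)
    ]
    assert len(gradients) == len(kept_invars)
    return dict(zip(kept_invars, gradients))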

frankxyy commented 1 year ago

In addition: this error happens when train_step is implemented like this:

def train_step(state, batch):
    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, train=True)[0]
        loss = loss_fn(logits, labels)
        return loss

    grad_fn = alpa.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grad)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}

    return new_state, metrics

but it does not happen when train_step is implemented like this:

def train_step(state, batch):
    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, train=True)[0]
        loss = loss_fn(logits, labels)
        return loss

    grad_fn = alpa.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grad)

    metrics = {"loss": loss}

    return new_state, None

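The only functional difference is that the failing version returns learning_rate, computed from state.step inside the traced function. A possible host-side workaround under that assumption (a sketch, not a confirmed fix): keep the schedule out of the traced function and recover the learning rate after the step.

# Hypothetical driver-side workaround, assuming the traced function returns
# {"loss": loss} and that state.step can be fetched back to the host:
state, metrics = p_train_step(state, batch)
metrics["learning_rate"] = linear_decay_lr_schedule_fn(int(state.step))
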
merrymercy commented 1 year ago

@ZYHowell Could you take a look?

ZYHowell commented 1 year ago

It seems this assertion was removed in https://github.com/alpa-projects/alpa/pull/681. Could you please try the nightly Alpa?
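
For reference, one way to pick up the unreleased main branch is to install from source: clone https://github.com/alpa-projects/alpa and run pip3 install -e . in the checkout (note that Alpa also needs its matching jaxlib build; see the installation docs).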

merrymercy commented 1 year ago

closed due to inactivity