Closed: cmsflash closed this issue 2 weeks ago
Hm... I'm less familiar with STK, @tgale96 any insights?
I managed to make a more minimal reproduction that is independent of MegaBlocks and only depends on STK, hence moving this issue to https://github.com/stanford-futuredata/stk/issues/11.
I am debugging a data-parallel forward mismatch when using megablocks (DP and non-DP give different forward results). While trying to reproduce the difference minimally, I found the following in SparseGLU.forward(): if you save x and w1 (by monkey-patching) right before https://github.com/databricks/megablocks/blob/f1a83bd55413b02b472696b719646cf22732d070/megablocks/layers/glu.py#L39 and then put x and w1 through this line (x1 = stk.ops.sdd(x, w1.t(), topo)), the output is wildly different if we simply .clone() x first (i.e. x1_clone = stk.ops.sdd(x.clone(), w1.t(), topo)).

Below is a minimal reproduction:
My relevant environment info: