Open jramapuram opened 5 months ago
I've also seen some issues with AMP. I think there's something missing somewhere... but all the functions seem wrapped to me?
@mvpatel2000: this can be worked around for moe.MoE by force-casting the module to float32 (moe.to(torch.float32)), and AMP then works fine. When I do the same with dmoe.dMoE I get a Triton error (repro sketch after the traceback below):
File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
return bwd(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/megablocks/layers/mlp.py", line 270, in backward
stk.backend.triton_kernels.sdd(
File "/miniconda/lib/python3.10/site-packages/stk/backend/triton_kernels.py", line 336, in sdd
_sdd_kernel[grid](
File "/miniconda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File "", line 63, in _sdd_kernel
File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
File "/miniconda/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 26:25:
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # do matrix multiplication
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
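For reference, this is roughly the pattern that works for moe.MoE and fails for dmoe.dMoE. It's a minimal sketch, not my exact code: the Arguments fields, the tensor shape, and whether forward() returns a tensor or an (output, bias) tuple are from memory and may differ across megablocks versions.

```python
# Minimal sketch (not my exact code): double-check the Arguments field names
# against your megablocks version.
import torch
from megablocks.layers import arguments, dmoe, moe

args = arguments.Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)

x = torch.randn(8, 512, 1024, device="cuda")  # (batch, seq, hidden)

# Workaround that trains fine: keep moe.MoE force-cast to fp32 under autocast.
moe_layer = moe.MoE(args).cuda().to(torch.float32)

# The same recipe with dmoe.dMoE hits the Triton CompilationError above.
dmoe_layer = dmoe.dMoE(args).cuda().to(torch.float32)

for layer in (moe_layer, dmoe_layer):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = layer(x)
        # Some megablocks versions return (output, bias) rather than a tensor.
        out = out[0] if isinstance(out, tuple) else out
        loss = out.float().mean()
    loss.backward()  # dmoe_layer raises the CompilationError from stk's sdd kernel
```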
Hi there,
Great work with dMoE! I'm trying to test dMoE with regular DDP + PyTorch AMP (BF16), and I hit the error shown above. I'm just wrapping your existing dmoe.dMoE(args) logic. Is this something that is currently unsupported? If I force the entire network to BF16, then everything works fine.
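To make the setup concrete, here is a stripped-down sketch of what I'm running. It assumes a torchrun launch, and the Arguments fields plus the single dMoE layer standing in for my full network are simplifications rather than my exact code.

```python
# Stripped-down repro sketch: plain DDP + torch.autocast(bf16) around dmoe.dMoE.
# Assumes launch via torchrun so LOCAL_RANK and the process-group env vars are set.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from megablocks.layers import arguments, dmoe

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

args = arguments.Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)
x = torch.randn(8, 512, 1024, device="cuda")

# 1) DDP + AMP (BF16): forward runs, backward dies in the Triton sdd kernel.
model = DDP(dmoe.dMoE(args).cuda(), device_ids=[local_rank])
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
    out = out[0] if isinstance(out, tuple) else out  # tuple return on some versions
    loss = out.float().mean()
loss.backward()  # <- CompilationError shown above
opt.step()

# 2) Forcing the entire network (and inputs) to BF16 instead of using autocast
#    trains without issue.
model = DDP(dmoe.dMoE(args).to(torch.bfloat16).cuda(), device_ids=[local_rank])
out = model(x.to(torch.bfloat16))
out = out[0] if isinstance(out, tuple) else out
out.float().mean().backward()  # OK
```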