databricks / megablocks


AMP + BF16 failing #95

Open · jramapuram opened 5 months ago

jramapuram commented 5 months ago

Hi there,

Great work with dMoE! I'm trying to test dMoE with regular DDP + PyTorch AMP (BF16), and I get the following error:

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 248, in _unscale_grads_
    torch._amp_foreach_non_finite_check_and_unscale_(

I'm just wrapping your existing dmoe.dMoE(args) logic.

Is this something that is currently unsupported? If I force the entire network to BF16 then everything works fine.
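
For reference, here's roughly what I'm doing. This is a minimal sketch, not my exact code: the Arguments fields and tensor shapes below are illustrative placeholders.

    import torch
    from megablocks.layers import dmoe
    from megablocks.layers.arguments import Arguments

    # Placeholder config -- illustrative values, not my actual setup.
    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
    )
    model = dmoe.dMoE(args).cuda()
    # (In the real run this is wrapped in torch.nn.parallel.DistributedDataParallel.)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(8, 512, 1024, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(x)
        if isinstance(out, tuple):  # some versions return an aux value too
            out = out[0]
        loss = out.float().pow(2).mean()  # dummy loss for the repro

    scaler.scale(loss).backward()
    scaler.unscale_(opt)  # <- raises in _amp_foreach_non_finite_check_and_unscale_
    scaler.step(opt)
    scaler.update()

(Side note: GradScaler mainly exists for float16; with bfloat16 autocast the scaler can usually be dropped entirely, but I'd expect it to degrade gracefully rather than crash.)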

mvpatel2000 commented 5 months ago

I've also seen some issues with AMP. I think there's something missing somewhere... but all the functions seem wrapped to me?

jramapuram commented 4 months ago

@mvpatel2000: this can be worked around for moe.MoE by force-casting the module with moe.to(torch.float32), and AMP then works fine.
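
A rough sketch of the workaround (same placeholder-config caveats as the sketch above):

    import torch
    from megablocks.layers import moe
    from megablocks.layers.arguments import Arguments

    args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
                     moe_num_experts=8, moe_top_k=2)
    block = moe.MoE(args).cuda()

    # Workaround: pin the MoE parameters (and hence their grads) to fp32.
    # GradScaler.unscale_ then only ever sees fp32 grads, presumably avoiding
    # the failure above; the rest of the network still runs under bf16 autocast.
    block.to(torch.float32)

When doing the same with dmoe.dMoE I get a Triton error: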

  File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/mlp.py", line 270, in backward
    stk.backend.triton_kernels.sdd(
  File "/miniconda/lib/python3.10/site-packages/stk/backend/triton_kernels.py", line 336, in sdd
    _sdd_kernel[grid](
  File "/miniconda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "", line 63, in _sdd_kernel
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in 
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 26:25:    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # do matrix multiplication
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)