databricks / megablocks

Apache License 2.0
1.11k stars 154 forks source link

Cloning input `x` in `megablocks.layers.glu.SparseGLU` leads to different SDD outputs #115

Closed cmsflash closed 2 weeks ago

cmsflash commented 1 month ago

I am debugging a data-parallel forward mismatch when using MegaBlocks (DP and non-DP runs give different forward results). While trying to reproduce the difference minimally, I found the following in SparseGLU.forward(): if you save x and w1 (by monkey-patching) right before https://github.com/databricks/megablocks/blob/f1a83bd55413b02b472696b719646cf22732d070/megablocks/layers/glu.py#L39 and then put x and w1 through that line (x1 = stk.ops.sdd(x, w1.t(), topo)), simply calling .clone() on x first (i.e. x1_clone = stk.ops.sdd(x.clone(), w1.t(), topo)) gives a wildly different output.

Below is a minimal reproduction:

import io
import os
import sys
import types

import stk
import torch
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE
from megablocks.layers.mlp import resolve_dtensor
from torch import distributed as dist

def nonempty(t: torch.Tensor, show_features: int = 8) -> torch.Tensor:
    """Treat the last dim as features, gather all non-empty features to a 2D tensor.

    Args:
        t: Tensor of shape (..., feature_count).
        show_features: The number of features to show. If None, show all.

    Returns:
        2D tensor of shape (nonempty_count, show_features).
    """
    if not isinstance(t, torch.Tensor):
        t = t.data
    t = t.reshape(-1, t.shape[-1])  # Reshape to (-1, d)
    t = t[(t != 0).any(dim=1)]  # Remove all rows that are all 0
    return t[..., :show_features]

def are_matrices_equal(a: "stk.Matrix", b: "stk.Matrix") -> bool:
    """Return True if two block-sparse matrices are exactly identical.

    Compares the logical shape, the block values (``data``), and every
    index/offset tensor, covering both the forward and the transposed
    metadata.

    ``torch.equal`` is used instead of element-wise ``==`` for two reasons.
    It returns ``False`` for tensors of different sizes, where ``==`` would
    broadcast or raise. It also yields a plain Python ``bool``, so the
    declared return type actually holds.
    """
    if tuple(a.shape) != tuple(b.shape):
        return False
    fields = (
        "row_indices",
        "data",
        "column_indices",
        "offsets",
        "column_indices_t",
        "offsets_t",
        "block_offsets_t",
    )
    return all(torch.equal(getattr(a, name), getattr(b, name)) for name in fields)

def glu_forward(self, x, topo):
    """Drop-in replacement for the expert MLP's forward that records its SDD operands.

    Intended to mirror the upstream sparse GLU forward pass:
    ``x1 = sdd(x, w1^T)`` and ``x2 = sdd(x, v1^T)``, then
    ``act(x1) * x2``, then ``dsd(., w2)``. As a debugging aid, it stores
    ``x`` and the resolved ``w1`` in ``self.act_dict`` immediately before
    the first ``stk.ops.sdd`` call, so the same operands can be replayed
    outside the model. This script binds it onto ``dmoe_.experts.mlp`` with
    ``types.MethodType``.

    Args:
        self: The expert MLP instance. It must provide ``args``, ``w1``,
            ``v1``, ``w2`` and ``scale_grad``.
        x: Dense left operand of both SDD calls. Its exact shape/layout is
            not visible here (TODO confirm against the caller).
        topo: ``stk.Matrix`` passed as the topology argument of
            ``stk.ops.sdd``; presumably it fixes which output blocks are
            computed.

    Returns:
        The result of ``stk.ops.dsd(x1, w2)``.

    Raises:
        NotImplementedError: If ``self.args.memory_optimized_mlp`` is set.
    """
    # Reset the capture dict on every call so it only holds the latest forward's operands.
    self.act_dict = {}
    if self.args.memory_optimized_mlp:
        raise NotImplementedError(
            "Memory optimized implementation not yet supported with GLU with sparse kernels."
        )

    w1, v1, w2 = (
        self.scale_grad(self.w1),
        self.scale_grad(self.v1),
        self.scale_grad(self.w2),
    )
    # NOTE(review): resolve_dtensor presumably converts DTensor parameters to
    # local tensors; confirm in megablocks.layers.mlp.
    w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)

    # Compute the GLU.
    # Record the exact operands of the first SDD call. These are references,
    # not copies, so try_sdd replays the very same tensor objects.
    self.act_dict["x"] = x
    self.act_dict["w1_resolved"] = w1
    x1 = stk.ops.sdd(x, w1.t(), topo)
    x2 = stk.ops.sdd(x, v1.t(), topo)

    # Gated activation: act(x1) combined with x2 via stk.ops.mul
    # (presumably element-wise over the sparse blocks).
    activation_fn_out = act_fn(x1, self.args.activation_fn)
    x1 = stk.ops.mul(activation_fn_out, x2)

    output = stk.ops.dsd(x1, w2)
    return output

def _init_single_process_group() -> None:
    """Point torch.distributed at localhost and start a 1-rank gloo group if none exists."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12362"
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=0, world_size=1)


def _build_patched_dmoe(dim: int) -> "dMoE":
    """Build a 2-expert, top-1, bf16 sparse-GLU dMoE on CUDA whose expert forward is glu_forward."""
    megablocks_args = Arguments(
        hidden_size=dim,
        ffn_hidden_size=128,
        bias=False,
        return_bias=False,
        activation_fn=torch.nn.functional.silu,
        moe_num_experts=2,
        moe_top_k=1,
        moe_loss_weight=0.05,
        moe_normalize_expert_weights=1.0,
        moe_jitter_eps=0.0,
        mlp_type="glu",
        mlp_impl="sparse",
        moe_expert_model_parallelism=False,
        expert_parallel_group=None,
        fp16=False,
        bf16=True,
        device=torch.device("cuda"),
    )
    dmoe_ = dMoE(megablocks_args)
    # Patch the instance (not the class) so only this model records SDD operands.
    dmoe_.experts.mlp.forward = types.MethodType(glu_forward, dmoe_.experts.mlp)
    return dmoe_


def _build_topology() -> "stk.Matrix":
    """Return the fixed 256x256 topology with two 128x128 blocks used for both SDD calls.

    Row and column indices are both [0, 1]. Assuming stk.Matrix's positional
    argument order, this is the two diagonal blocks.
    """
    return stk.Matrix(
        (256, 256),
        torch.empty((2, 128, 128), device="meta", dtype=torch.bfloat16),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1, 2], device="cuda:0", dtype=torch.int32),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1, 2], device="cuda:0", dtype=torch.int32),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int32),
    )


def _format_report(x, x_clone, x1, x1_clone, equal: bool) -> str:
    """Render a text breakdown comparing the original and cloned SDD outputs.

    Output is written straight into a buffer with ``print(..., file=...)``
    rather than by reassigning ``sys.stdout``. This has two benefits: an
    exception while formatting can no longer leave stdout pointing at a
    closed buffer, and whatever stdout was active beforehand is left
    untouched.
    """
    with io.StringIO() as buf:

        def emit(*parts) -> None:
            print(*parts, file=buf)

        emit("Input X is the same:", (x_clone == x).all())
        emit("SDD output is the same:", equal)
        emit()

        emit("Breakdown of SDD output:")
        emit("Shape:", x1_clone.shape == x1.shape)
        emit("Data:", x1_clone.data.allclose(x1.data))
        emit("Row indices:", x1_clone.row_indices == x1.row_indices)
        emit("Column indices:", x1_clone.column_indices == x1.column_indices)
        emit("Offsets:", x1_clone.offsets == x1.offsets)
        emit("Block offsets:", x1_clone.block_offsets_t == x1.block_offsets_t)
        emit()

        rows = nonempty(x1)
        rows_clone = nonempty(x1_clone)
        emit("Breakdown of SDD output data:")
        emit("Nonempty elements:", rows.shape, rows_clone.shape)
        emit("Per-row mean:", rows.mean(dim=1), rows_clone.mean(dim=1))
        # Entry [i, j] is sum |orig_row_j - clone_row_i| over the shown features.
        # Taking abs() keeps positive and negative differences from cancelling,
        # so a 0 entry really means the two rows match.
        emit(
            "Cross-equality (sum of |diff|, 0 = match):",
            (rows[None, :, :] - rows_clone[:, None, :]).abs().sum(-1),
        )
        return buf.getvalue()


def try_sdd() -> tuple[bool, str]:
    """Run one comparison of stk.ops.sdd on a captured input versus its clone.

    The steps are:
    1. Build a fresh patched dMoE and run a forward pass on random bf16
       input. The pass populates ``act_dict`` with the SDD input ``x`` and
       the resolved ``w1``.
    2. Compute ``stk.ops.sdd`` on ``x`` and on ``x.clone()``, using the same
       weight and a fixed topology.

    Returns:
        A tuple ``(equal, report)``. ``equal`` is True when the two SDD
        outputs match according to ``are_matrices_equal``. ``report`` is a
        human-readable diagnostic breakdown.
    """
    _init_single_process_group()

    dim = 8
    dmoe_ = _build_patched_dmoe(dim)

    input_ = torch.randn([1, 2, dim], dtype=torch.bfloat16, device=torch.device("cuda"))
    dmoe_(input_)  # forward pass fills dmoe_.experts.mlp.act_dict via glu_forward

    topo = _build_topology()

    x = dmoe_.experts.mlp.act_dict["x"]
    w = dmoe_.experts.mlp.act_dict["w1_resolved"]
    x1 = stk.ops.sdd(x, w.t(), topo)

    x_clone = x.clone()
    x1_clone = stk.ops.sdd(x_clone, w.t(), topo)
    # bool() normalizes the result whether the comparator yields a bool or a 0-d tensor.
    equal = bool(are_matrices_equal(x1_clone, x1))

    return equal, _format_report(x, x_clone, x1, x1_clone, equal)

def main() -> None:
    """Repeat the clone-vs-original SDD check, reporting the first failing trial.

    Stops at the first mismatch and prints that trial's diagnostic report.
    If every trial matches, prints a single success line.
    """
    trials = 10
    for idx in range(trials):
        ok, report = try_sdd()
        if ok:
            continue
        print(f"Trial {idx} failed.")
        print(report)
        break
    else:
        # Only reached when the loop finished without a break, i.e. every trial passed.
        print(f"All {trials} repetitions passed.")

if __name__ == "__main__":
    # Run the repeated SDD clone-comparison check when executed as a script.
    main()

My relevant environment info:

$ pipdeptree -p megablocks
megablocks==0.5.1
├── stanford-stk [required: >=0.0.6, installed: 0.7.0]
│   └── triton [required: >=2.1.0, installed: 2.2.0]
│       └── filelock [required: Any, installed: 3.13.4]
└── triton [required: >=2.1.0, installed: 2.2.0]
    └── filelock [required: Any, installed: 3.13.4]

$ pipdeptree -p torch
torch==2.1.0+cu121py311stripe
├── filelock [required: Any, installed: 3.13.4]
├── fsspec [required: Any, installed: 2023.10.0]
├── Jinja2 [required: Any, installed: 3.1.3]
│   └── MarkupSafe [required: >=2.0, installed: 2.1.5]
├── networkx [required: Any, installed: 3.3]
├── sympy [required: Any, installed: 1.12]
│   └── mpmath [required: >=0.19, installed: 1.3.0]
└── typing_extensions [required: Any, installed: 4.11.0]

$ python -V
Python 3.11.7

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:        20.04
Codename:       focal
mvpatel2000 commented 1 month ago

Hmm... I'm less familiar with STK. @tgale96, any insights?

cmsflash commented 2 weeks ago

I managed to create a more minimal reproduction that is independent of MegaBlocks and depends only on STK, so I am moving this issue to https://github.com/stanford-futuredata/stk/issues/11.