shawntan / scattermoe

Triton-based implementation of Sparse Mixture of Experts.
Apache License 2.0

Accuracy Issues #5

Closed: jeromeku closed this issue 3 months ago

jeromeku commented 3 months ago

@shawntan

Issue

Testing scattermoe.mlp.GLUMLP against the reference HF MixtralSparseMoeBlock gives significant discrepancies in the output.

Repro

With batch size 1, sequence length 5, and hidden dim 4096, the max absolute error is on the order of 0.1 to 0.5 for both float16 and float32.

Env:

Full script:

import torch
import torch.nn.functional as F
from transformers import AutoConfig
from transformers.models.mixtral.modeling_mixtral import (
    MixtralBlockSparseTop2MLP,
    MixtralSparseMoeBlock,
)

import scattermoe

def convert_mixtral_block_to_scattermoe(
    mixtral_moe_block: MixtralSparseMoeBlock,
) -> scattermoe.mlp.GLUMLP:

    # First merge gate and up proj weights for all experts, E x (intermediate_size * 2) x hidden_dim
    merged_w1w3_list = []
    for expert in mixtral_moe_block.experts:
        merged_w1w3_list.append(
            torch.cat([expert.w1.weight.T, expert.w3.weight.T], dim=1)
        )
    merged_w1w3 = torch.stack(merged_w1w3_list, dim=0).permute(0, 2, 1)

    # Stack w2 weights into single weight for output_experts, E x hidden_dim x intermediate_size
    merged_w2 = torch.stack(
        [expert.w2.weight for expert in mixtral_moe_block.experts], dim=0
    )

    act_fn = mixtral_moe_block.experts[0].act_fn

    scatter_mlp = scattermoe.mlp.GLUMLP(
        input_size=mixtral_moe_block.hidden_dim,
        hidden_size=mixtral_moe_block.ffn_dim,
        num_experts=mixtral_moe_block.num_experts,
        top_k=mixtral_moe_block.top_k,
        activation=act_fn,
    )
    # w1w3_before = scatter_mlp.experts.weight.mean()
    # w2_before = scatter_mlp.output_experts.weight.mean()
    # Check that the merged weights match expected shape
    assert scatter_mlp.experts.weight.shape == merged_w1w3.shape
    assert scatter_mlp.output_experts.weight.shape == merged_w2.shape

    # Copy weights into GluMLP
    scatter_mlp.experts.weight = torch.nn.Parameter(merged_w1w3)
    scatter_mlp.output_experts.weight = torch.nn.Parameter(merged_w2)

    # w1w3_after = scatter_mlp.experts.weight.mean()
    # w2_after = scatter_mlp.output_experts.weight.mean()

    # print(f"w1w3_before={w1w3_before}, w1w3_after={w1w3_after}")
    # print(f"w2_before={w2_before}, w2_after={w2_after}")
    return scatter_mlp

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L843-L888
def router(hidden_states: torch.Tensor, router_gate: torch.nn.Linear, top_k: int):
    _, _, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    router_logits = router_gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    routing_weights = routing_weights.to(hidden_states.dtype)
    return routing_weights, selected_experts

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L843-L888
def ref_forward(
    experts: list[MixtralBlockSparseTop2MLP],
    hidden_states: torch.Tensor,
    routing_weights: torch.Tensor,
    selected_experts: torch.Tensor,
):
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Derive the expert count from the module list instead of the global `config`
    num_experts = len(experts)
    expert_mask = torch.nn.functional.one_hot(
        selected_experts, num_classes=num_experts
    ).permute(2, 1, 0)

    for expert_idx in range(num_experts):
        expert_layer = experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])

        if top_x.shape[0] == 0:
            continue

        # in torch it is faster to index using lists than torch tensors
        top_x_list = top_x.tolist()
        idx_list = idx.tolist()

        # Index the correct hidden states and compute the expert hidden state for
        # the current expert. We need to make sure to multiply the output hidden
        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
        current_hidden_states = (
            expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
        )

        # However, `index_add_` only supports torch tensors for indexing, so we
        # use the `top_x` tensor here.
        final_hidden_states.index_add_(
            0, top_x, current_hidden_states.to(hidden_states.dtype)
        )

    final_hidden_states = final_hidden_states.reshape(
        batch_size, sequence_length, hidden_dim
    )

    return final_hidden_states

def test_scatter_mlp(
    config, batch_size=1, sequence_length=5, dtype=torch.float16, device="cuda"
):
    hidden_dim = config.hidden_size
    hidden_states = torch.randn(
        batch_size,
        sequence_length,
        hidden_dim,
        dtype=dtype,
        device=device,
    )

    top_k = config.num_experts_per_tok

    mixtral_moe_block = MixtralSparseMoeBlock(config).to(dtype).to(device)
    scatter_mlp = (
        convert_mixtral_block_to_scattermoe(mixtral_moe_block).to(dtype).to(device)
    )

    routing_weights, selected_experts = router(
        hidden_states, mixtral_moe_block.gate, top_k
    )

    ref_out = ref_forward(
        mixtral_moe_block.experts, hidden_states, routing_weights, selected_experts
    )
    scatter_out = scatter_mlp(hidden_states, routing_weights, selected_experts)
    print((ref_out - scatter_out).abs().max())

torch.manual_seed(42)
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
config = AutoConfig.from_pretrained(model_id)
with torch.no_grad():
    test_scatter_mlp(config)
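
The script above prints only the maximum absolute difference, which is hard to judge without knowing the scale of the outputs. A minimal sketch of a helper that also reports mean and relative error; `report_error` is a hypothetical name, not part of scattermoe or transformers:

```python
import torch


def report_error(ref: torch.Tensor, out: torch.Tensor, eps: float = 1e-6) -> dict:
    """Summarise how far `out` is from `ref` (hypothetical helper).

    Both tensors are upcast to float64 first, so the comparison itself adds
    essentially no rounding error of its own.
    """
    ref64, out64 = ref.double(), out.double()
    abs_err = (ref64 - out64).abs()
    # Per-element relative error; eps guards against division by zero
    rel_err = abs_err / ref64.abs().clamp_min(eps)
    return {
        "max_abs": abs_err.max().item(),
        "mean_abs": abs_err.mean().item(),
        "max_rel": rel_err.max().item(),
        # Scale of the reference output, to put the absolute numbers in context
        "ref_max_abs": ref64.abs().max().item(),
    }
```

For example, `print(report_error(ref_out, scatter_out))` could replace the final `print` in `test_scatter_mlp`. For a pass/fail check, `torch.testing.assert_close(scatter_out, ref_out, rtol=..., atol=...)` performs a similar comparison with explicit tolerances.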

casper-hansen commented 3 months ago

I observed a similar difference in the logits. There may be a slight precision issue in the kernels.

jeromeku commented 3 months ago

@casper-hansen

Running the script above gives a max abs error of 0.1848.

I'm also interested in running scattermoe with pretrained models, per your PR #2.

I have a working implementation that converts pretrained models to scattermoe, but precision is a major concern (or I've made a misstep in my implementation).

shawntan commented 3 months ago

Working on the pretrained models right now.

I am observing the same thing. The likely cause is that the original implementation keeps the router weights in float32, whereas my modification casts them to bfloat16.
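
To get a feel for how much the cast alone can contribute, the sketch below routes random logits the same way as the reference `router` in the first script (float32 softmax, top-k, renormalise) and measures how far the weights move when rounded to bfloat16. It is an illustration under those assumptions, not a measurement of the scattermoe code path, and it ignores the separate effect of computing the gate logits in bfloat16, which can also change which experts get selected.

```python
import torch


def routing_weight_rounding_error(logits: torch.Tensor, top_k: int = 2) -> float:
    """Max abs change in top-k routing weights caused by a bfloat16 round-trip."""
    # Same routing as the reference: softmax in float32, top-k, renormalise
    probs = torch.softmax(logits.float(), dim=-1)
    weights, _ = torch.topk(probs, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    # Only difference: cast the final weights to bf16 and back
    weights_bf16 = weights.to(torch.bfloat16).float()
    return (weights - weights_bf16).abs().max().item()


torch.manual_seed(0)
print(routing_weight_rounding_error(torch.randn(1000, 8)))
```

Because bfloat16 keeps 8 significant bits, each routing weight changes by at most about 0.4% relative (at most 2^-9 ≈ 0.002 absolute for a weight in [0.5, 1)), so this cast scales each expert's contribution by a correspondingly small factor.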

CanyonWind commented 3 months ago

@shawntan Thanks for the insight. This could potentially be verified quickly: define and load the router in fp32. It also seems there's no need to cast back to bf16, since the router only decides which experts are activated; the experts' actual computations are still in bf16 anyway. A conclusion on the discrepancy would be really beneficial.
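
A minimal sketch of this suggestion, assuming a Mixtral-style top-k router: the hidden states are upcast before the gate, logits and softmax run in float32, and the routing weights are returned in float32 without casting back. `Float32Router` is a hypothetical name, and whether scattermoe's kernels accept float32 routing weights alongside bfloat16 activations has not been checked here, so a cast may still be needed at the call site.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Float32Router(nn.Module):
    """Top-k router whose gate matmul and softmax run in float32 (sketch)."""

    def __init__(self, hidden_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        # Same shape as Mixtral's gate: hidden_dim -> num_experts, no bias
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False, dtype=torch.float32)

    def forward(self, hidden_states: torch.Tensor):
        # Flatten (batch, seq, hidden) -> (tokens, hidden) and upcast to float32
        x = hidden_states.reshape(-1, hidden_states.shape[-1]).float()
        # .float() is a no-op if the gate is already fp32; it cannot restore
        # precision if the gate weight was stored in bf16
        logits = F.linear(x, self.gate.weight.float())
        probs = F.softmax(logits, dim=-1)
        weights, experts = torch.topk(probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        # Weights stay in float32; no cast back to the activation dtype
        return weights, experts, logits
```

Calling `.to(torch.bfloat16)` on a parent module also casts the gate, which is why the sketch upcasts the gate weight inside `forward` rather than relying on the stored dtype.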

shawntan commented 3 months ago

I checked, the discrepancy is about 0.01 in float32. I'm still studying the problem.

jeromeku commented 3 months ago

@shawntan

Can you point me to the script reproducing that result? Happy to help debug.

jeromeku commented 3 months ago

> I checked, the discrepancy is about 0.01 in float32. I'm still studying the problem.

@shawntan Where is the script for reproducing this?

shawntan commented 3 months ago

> I checked, the discrepancy is about 0.01 in float32. I'm still studying the problem.
>
> @shawntan Where is the script for reproducing this?

import torch
import transformers
# Local Mixtral modeling code modified to use ScatterMoE (not the HF module)
from mixtral.modeling_mixtral import MixtralSparseMoeBlock
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"

if __name__ == "__main__":
    device = torch.device('cuda:0')
    dtype = torch.float32
    config = transformers.MixtralConfig.from_pretrained(MODEL_NAME)
    num_experts = config.num_local_experts

    x = torch.randn(8, 2048, 4096).to(dtype=dtype, device=device)

    sm_mlp = MixtralSparseMoeBlock(config).to(dtype=dtype, device=device)
    hf_mlp = transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock(config).to(dtype=dtype, device=device)
    hf_state_dict = hf_mlp.state_dict()
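    # Copy the HF weights into the ScatterMoE block: parameters whose names
    # match are copied directly; the per-expert weights are stacked into the
    # merged `moe_mlp` parameters.
    for n, p in sm_mlp.named_parameters():
        if n in hf_state_dict:
            p.data[:] = hf_state_dict.pop(n)
        else:
            prefix, suffix = n.split('moe_mlp')
            for i in range(num_experts):
                if suffix == ".output_experts.weight":
                    # Down projection: slot i <- experts.i.w2
                    w2_param_name = prefix + "experts.%d.w2.weight" % i
                    p.data[i, :, :] = hf_state_dict.pop(w2_param_name)
                else:
                    # Merged projection: w3 (up) fills the first out_dim rows, w1 (gate) the rest
                    w1_param_name = prefix + "experts.%d.w1.weight" % i
                    w3_param_name = prefix + "experts.%d.w3.weight" % i
                    out_dim, in_dim = hf_state_dict[w1_param_name].size()
                    p.data[i, :out_dim, :] = hf_state_dict.pop(w3_param_name)
                    p.data[i, out_dim:, :] = hf_state_dict.pop(w1_param_name)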

    hf_out = hf_mlp(x)[0]
    sm_out = sm_mlp(x)[0]
    diff = torch.abs(hf_out - sm_out)
    print("max diff:", diff.max())

I think the issue is the way I do the gating inside GLUMLP: your w3 and w1 concatenation order should be swapped. Sorry about that.
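
Going by the loading code in the script above, the merged expert weight holds w3 (up projection) in its first `ffn_dim` rows and w1 (gate projection) in the rest, i.e. GLUMLP applies the activation to the second half. A sketch of the conversion with that order in plain PyTorch; `merge_mixtral_expert_weights` and `glu_expert_forward` are hypothetical helpers, and the latter only restates the assumed gating so the order can be checked against Mixtral's `w2(act(w1 x) * w3 x)`:

```python
import torch
import torch.nn.functional as F


def merge_mixtral_expert_weights(w1s, w2s, w3s):
    """Stack per-expert Mixtral weights in the [w3; w1] order used above.

    w1s / w3s: per-expert (ffn_dim, hidden_dim) gate / up projections
    w2s:       per-expert (hidden_dim, ffn_dim) down projections
    Returns (E, 2 * ffn_dim, hidden_dim) and (E, hidden_dim, ffn_dim).
    """
    merged_w3w1 = torch.stack([torch.cat([w3, w1], dim=0) for w1, w3 in zip(w1s, w3s)])
    merged_w2 = torch.stack(list(w2s))
    return merged_w3w1, merged_w2


def glu_expert_forward(x, merged_w3w1, merged_w2, e, act=F.silu):
    """Expert e, assuming first half = up projection, second half = gate."""
    h, gates = (x @ merged_w3w1[e].T).chunk(2, dim=-1)
    return (act(gates) * h) @ merged_w2[e].T
```

In the first script, this amounts to concatenating `[expert.w3.weight.T, expert.w1.weight.T]` instead of `[expert.w1.weight.T, expert.w3.weight.T]` in `convert_mixtral_block_to_scattermoe`.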

shawntan commented 3 months ago

Sorry for the erratic behaviour. I was trying to fix some really weird behaviour when running lm-eval, and I still haven't figured it out. At first I thought the NaN issues I was getting were due to precision problems in the kernel, so I tried all manner of fixes to keep the numerical errors as low as possible. It turns out that turning off TF32 brings the discrepancy close to 0, though leaving it on isn't too bad either; this is a known effect of TF32, after all.
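
A minimal sketch for switching TF32 off around a comparison, assuming PyTorch 1.12 or newer (where `torch.set_float32_matmul_precision` exists). It governs PyTorch's own float32 matmuls, such as those in the HF reference; whether a Triton kernel uses TF32 is decided by the kernel itself, so this is not claimed to affect scattermoe's kernels.

```python
import contextlib
from typing import Iterator

import torch


@contextlib.contextmanager
def tf32_disabled() -> Iterator[None]:
    """Temporarily force full-precision float32 matmuls, then restore the old setting."""
    previous = torch.get_float32_matmul_precision()
    # "highest" disables TF32 for PyTorch float32 matmuls
    torch.set_float32_matmul_precision("highest")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(previous)
```

Wrapping both forward passes in `with tf32_disabled():` helps separate TF32 rounding from genuine bugs in a comparison like the ones above.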

The issue seems to be weirder than that.

Asking here in case anyone knows: when running parallelized across GPUs, everything that happens on cuda:0 goes well, but once execution is handed over to cuda:1, it's as if the kernel never executes, and the empty buffer I initialised is passed back to me unchanged. If I switch from torch.empty to torch.zeros, I get a zero buffer back.
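
One debugging aid for this "kernel never ran" symptom, offered as a hypothetical technique rather than anything from scattermoe: allocate the output buffer filled with NaN instead of using torch.empty or torch.zeros, so any element the kernel never wrote is unambiguous. This assumes a floating-point output and that legitimate outputs are finite; `alloc_sentinel` and `unwritten_fraction` are illustrative names.

```python
import torch


def alloc_sentinel(shape, dtype=torch.float32, device="cpu") -> torch.Tensor:
    """Allocate an output buffer pre-filled with NaN (debugging aid).

    torch.empty returns whatever was in memory and torch.zeros returns
    plausible-looking zeros; a NaN fill makes "never written" obvious.
    """
    return torch.full(shape, float("nan"), dtype=dtype, device=device)


def unwritten_fraction(buf: torch.Tensor) -> float:
    """Fraction of elements still NaN, i.e. never overwritten by the kernel."""
    return torch.isnan(buf).float().mean().item()
```

If `unwritten_fraction(out)` is 1.0 only for tensors on cuda:1, the launch is not reaching that device at all.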

casper-hansen commented 3 months ago

@shawntan I have some experience with things not working as expected when going from device 0 to device 1. Usually it's due to missing device guards, or to not passing everything to the right device. However, these issues are incredibly hard to track down, and adding print statements usually results in incorrectly reported values. That's my 2 cents on this issue.

Can you try wrapping some of your modeling code in a with block to make sure?

with torch.cuda.device(hidden_states.device):
    ...  # launch the Triton kernels inside this block

For reference, here is how AutoGPTQ calls its Triton kernels: https://github.com/AutoGPTQ/AutoGPTQ/blob/74212f501396827772bf9d915b079bd6a419bd92/auto_gptq/nn_modules/triton_utils/kernels.py#L347-L374
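
The with-block fix can be generalised into a small decorator that applies the same torch.cuda.device guard around any kernel-launching function. As far as we understand, Triton launches on the current CUDA device (cuda:0 unless changed), which would explain untouched buffers on cuda:1. `cuda_device_guard` is a hypothetical helper, and taking the device from the first tensor argument is an assumption that suits functions whose tensors all live on one device:

```python
import contextlib
import functools
from typing import Any, Callable, TypeVar

import torch

FnT = TypeVar("FnT", bound=Callable[..., Any])


def cuda_device_guard(fn: FnT) -> FnT:
    """Run fn with the current CUDA device set to that of its first tensor argument."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        tensors = [a for a in (*args, *kwargs.values()) if isinstance(a, torch.Tensor)]
        device = tensors[0].device if tensors else None
        if device is not None and device.type == "cuda":
            # Make kernel launches inside fn target this tensor's GPU
            ctx = torch.cuda.device(device)
        else:
            # CPU tensors or no tensors: nothing to guard
            ctx = contextlib.nullcontext()
        with ctx:
            return fn(*args, **kwargs)

    return wrapper  # type: ignore[return-value]
```

Decorating the function that launches the Triton kernels keeps the guard in one place instead of repeating the with block at every call site.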

shawntan commented 3 months ago

It works! Thanks @casper-hansen, great tip. I'll keep that in mind in the future.