In my end-to-end tests on 8 H100 GPUs, the column-major MoE kernel showed no speedup. However, when I ran the script `kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py`, I did see a 2–3× speedup. I then compared the column-major MoE against vLLM's MoE on 2 H100 GPUs and measured about a 25% speedup, but that speedup was not observed on 8 H100 GPUs. Is this result reasonable?
To reproduce:
# Docker
docker run -it --gpus all --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
# Install the vLLM library
pip install vllm==v0.4.0.post1
# Clone the vLLM repo (for the benchmark scripts)
git clone https://github.com/vllm-project/vllm.git
cd vllm/benchmarks/; git checkout tags/v0.4.0.post1
# Download Mixtral
from huggingface_hub import snapshot_download, login
hf_token = "You should use your own token!"
login(token=hf_token)
snapshot_download(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Run benchmark with vllm MoE
# If you see an error about quantization_param_path, comment out the line that references it and run again.
python benchmark_throughput.py \
--model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len=256 --output-len=64 \
--tensor-parallel-size=8 --num-prompts 400 --worker-use-ray
# Run benchmark with column-major MoE
# Open the installed vLLM fused MoE module:
vim /usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/fused_moe/fused_moe.py
# Replace its contents with Column-Major_fused_moe.py (listed below)
python benchmark_throughput.py \
--model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len=256 --output-len=64 \
--tensor-parallel-size=8 --num-prompts 400 --worker-use-ray
Both benchmark runs (vLLM MoE and column-major MoE) gave roughly the same throughput.
The code for Column-Major_fused_moe.py:
"""Column-major Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip
import time
# Set to True to have fused_moe print the elapsed time of each call,
# measured on the host with time.perf_counter().
MEASURE_TIME = False
# Module-level logger (used by get_moe_configs).
logger = init_logger(__name__)
@triton.jit()
def col_major(pid,
              m, n,
              block_m: tl.constexpr, block_n: tl.constexpr):
    """Map a flat program id to a (pid_m, pid_n) output-tile index, column-major.

    Consecutive program ids walk down the M dimension first (pid_m varies
    fastest), so consecutive programs share the same pid_n.

    Args:
        pid: flat program id in [0, cdiv(m, block_m) * cdiv(n, block_n)).
        m, n: problem sizes along M and N. `n` and `block_n` are kept for
            interface compatibility; the column-major mapping only needs the
            number of M tiles.
        block_m, block_n: tile sizes along M and N.

    Returns:
        (pid_m, pid_n): a bijection from pid onto the
        cdiv(m, block_m) x cdiv(n, block_n) tile grid.
    """
    grid_m = tl.cdiv(m, block_m)
    # Both the modulus and the divisor must be grid_m. The previous
    # `pid % grid_n` was only a bijection when grid_m == grid_n; otherwise some
    # output tiles were never computed, others were computed twice, and pid_m
    # could exceed the M tile range.
    pid_m = pid % grid_m
    pid_n = pid // grid_m
    return pid_m, pid_n
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_weight,
    # NOTE(review): stride_token_id is accepted but never read by this kernel.
    stride_token_id,
    # Meta-parameters
    block_m: tl.constexpr,
    block_n: tl.constexpr,
    block_k: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor with shape (*, K). Row `i // top_k` of A is used for
        the flattened (token, k) slot `i`.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension. The caller passes B.stride(2) as
        stride_bk and B.stride(1) as stride_bn, so each expert's (N, K)
        matrix is read transposed as (K, N).
    - C: The output tensor. The caller allocates it as (M, topk, N), where M
        is the number of tokens (rows of hidden_states, not the padded count).
        It is addressed by flattened (token, k) slot via stride_cm.
    - sorted_token_ids: Flattened (token, k) slot indices grouped by expert
        and padded per expert to a multiple of block_m. Padding entries hold
        a value >= num_valid_tokens and are masked out.
    - expert_ids: The expert index for each block_m-sized block of
        sorted_token_ids. It selects which expert matrix of B that block uses.

    Each program computes one (block_m x block_n) tile of C. Because padding
    keeps every block_m rows within a single expert, a whole tile can be
    multiplied by one expert matrix.
    """
    pid = tl.program_id(axis=0)
    # Translate the flat program id into an (M-block, N-block) tile index.
    pid_m, pid_n = col_major(pid,
                             EM, N,
                             block_m, block_n,)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Blocks starting past the padded slot count have no work to do.
    if pid_m * block_m >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * block_m + tl.arange(0, block_m)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding slots hold an index >= num_valid_tokens; exclude them.
    token_mask = offs_token < num_valid_tokens
    # Out-of-range N columns wrap via % N so B loads stay in bounds; those
    # columns are discarded by c_mask at the store.
    offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N
    offs_k = tl.arange(0, block_k)
    # offs_token is a flattened (token, k) slot; `// top_k` gives the row of A.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)
    # Every row in this block belongs to the same expert.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[block_m, block_n]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to `compute_type` after the loop.
    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, block_k)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * block_k),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * block_k,
                    other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += block_k * stride_ak
        b_ptrs += block_k * stride_bk
    if MUL_ROUTED_WEIGHT:
        # Scale each row by the routing weight of its (token, k) slot.
        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * block_n + tl.arange(0, block_n)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    # Skip padding rows and the wrapped (>= N) columns.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group the flattened (token, k) routing slots by expert, padding each
    expert's group to a multiple of ``block_size`` so that every block of the
    downstream block matmul touches a single expert.

    Parameters:
    - topk_ids: [total_tokens, top_k] tensor of selected expert indices.
    - block_size: Block size (along M) used by the block matrix multiplication.
    - num_experts: Total number of experts.

    Returns:
    - sorted_ids: Slot indices grouped by expert. The buffer is sized for the
      worst-case padding. Padding entries hold ``topk_ids.numel()``, one past
      the last valid slot index, so they can be masked out downstream.
    - expert_ids: The expert index assigned to each block of ``sorted_ids``.
    - num_tokens_post_pad: One-element tensor holding the padded slot count.

    Example (experts indexed 0..3):
        topk_ids = [[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]],
        block_size = 4, num_experts = 4.
        The flattened ids are [1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2]. Each expert
        receives 3 slots and is padded by one slot with value 12, giving e.g.
        sorted_ids beginning
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].

    The grouping itself is performed by the ``ops.moe_align_block_size``
    custom op, which fills the three output buffers allocated here.
    """
    device = topk_ids.device
    num_slots = topk_ids.numel()
    # Worst case: each expert needs up to (block_size - 1) padding slots.
    max_padded_slots = num_slots + num_experts * (block_size - 1)

    sorted_ids = torch.empty((max_padded_slots, ),
                             dtype=torch.int32,
                             device=device).fill_(num_slots)
    expert_ids = torch.empty((num_slots + num_experts, ),
                             dtype=torch.int32,
                             device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int,
                            config: Dict[str, Any]) -> None:
    """
    Launch ``fused_moe_kernel`` with one program per (block_m x block_n) tile.

    - A: input rows, indexed in the kernel by ``sorted_token_ids // top_k``.
    - B: stacked expert weights (E, N, K). Its K and N strides are passed
      swapped, so the kernel reads each expert matrix as (K, N).
    - C: output, addressed by flattened (token, k) slot through C.stride(1)
      and by column through C.stride(2).
    - config: must provide the kernel's ``block_m``, ``block_n`` and
      ``block_k`` meta-parameters.

    Only bf16 inputs compute in bf16; every other dtype uses fp16 as the
    kernel's compute_type.
    """
    # The kernel treats both tensors as flat, unit-stride arrays.
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    padded_rows = sorted_token_ids.shape[0]
    out_features = B.shape[1]
    in_features = B.shape[2]

    def grid(meta):
        tiles_m = triton.cdiv(padded_rows, meta['block_m'])
        tiles_n = triton.cdiv(out_features, meta['block_n'])
        return (tiles_m * tiles_n, )

    if A.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    else:
        compute_type = tl.float16

    fused_moe_kernel[grid](
        # Pointers
        A,
        B,
        C,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        # Sizes: N, K, EM, num_valid_tokens
        out_features,
        in_features,
        padded_rows,
        topk_ids.numel(),
        # A strides (m, k)
        A.stride(0),
        A.stride(1),
        # B strides (expert, k, n)
        B.stride(0),
        B.stride(2),
        B.stride(1),
        # C strides (flattened slot, n)
        C.stride(1),
        C.stride(2),
        # stride_weight, stride_token_id
        topk_weights.stride(1),
        sorted_token_ids.stride(0),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        **config,
    )
def get_config_file_name(E: int, N: int) -> str:
    """Return the tuned-config JSON file name for E experts, size N, and this GPU.

    Spaces in the CUDA device name are replaced by underscores.
    """
    device_name = "_".join(torch.cuda.get_device_name().split(" "))
    return f"E={E},N={N},device_name={device_name}.json"
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
    """
    Load tuned fused-MoE kernel configurations for (E, N) on the current GPU.

    The configs are read from ``configs/<get_config_file_name(E, N)>`` next to
    this file. The result maps batch sizes (an irregular grid, as int keys) to
    kernel configurations; callers are expected to pick the entry whose batch
    size is closest to the actual one.

    Returns None when no such file exists. Results are memoized per (E, N).
    """
    configs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               "configs")
    config_file_path = os.path.join(configs_dir, get_config_file_name(E, N))

    if not os.path.exists(config_file_path):
        # No tuned configuration is available for this (E, N, device).
        return None

    with open(config_file_path) as f:
        logger.info(
            f"Using configuration from {config_file_path} for MoE layer.")
        raw_configs = json.load(f)
    # JSON object keys are strings; convert the batch sizes back to int.
    return {int(batch_size): cfg for batch_size, cfg in raw_configs.items()}
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,  # Not in Pytorch-Labs col_major
    topk: int,  # Pytorch-Labs pass topk_weights and topk_ids
    renormalize: bool,  # Not in Pytorch-Labs col_major
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,  # Not in Pytorch-Labs col_major
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer with two sets of expert weights,
    w1 and w2, and top-k gating, using the column-major fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The (M, K) input tensor to the MoE layer.
        Must be contiguous and float32/float16/bfloat16.
    - w1 (torch.Tensor): The first set of expert weights, shape (E, N, K).
        Its output goes through silu_and_mul, which produces N // 2 features.
    - w2 (torch.Tensor): The second set of expert weights. Its last dimension
        must match the N // 2 features produced by silu_and_mul.
    - gating_output (torch.Tensor): Router logits of shape (M, E), before
        softmax.
    - topk (int): The number of experts selected per token.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, write the result into hidden_states.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Accepted for signature
        compatibility with vLLM's fused_moe but currently IGNORED; this
        implementation always uses the fixed tile configuration below.

    Returns:
    - torch.Tensor: The (M, w2.shape[1]) output after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if MEASURE_TIME:
        # CUDA kernel launches are asynchronous. Drain previously queued work
        # so it is not attributed to this call.
        torch.cuda.synchronize()
        start_time = time.perf_counter()

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output,
                                        dim=-1,
                                        dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import vllm._moe_C as moe_kernels

        topk_weights = torch.empty(M,
                                   topk,
                                   dtype=torch.float32,
                                   device=hidden_states.device)
        topk_ids = torch.empty(M,
                               topk,
                               dtype=torch.int32,
                               device=hidden_states.device)
        token_expert_indicies = torch.empty(M,
                                            topk,
                                            dtype=torch.int32,
                                            device=hidden_states.device)
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Fixed tile sizes for the column-major kernel (override_config is ignored).
    config = {
        'block_m': 128,
        'block_n': 128,
        'block_k': 64,
    }

    # cache1: w1 output per (token, k) slot.
    # cache2: silu_and_mul output, one row per slot, N // 2 features.
    # cache3: w2 output per (token, k) slot.
    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config['block_m'], E)

    # First GEMM: each routed (token, k) slot reads row token of hidden_states.
    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
                            topk_weights, topk_ids, sorted_token_ids,
                            expert_ids, num_tokens_post_padded, False,
                            topk_ids.shape[1], config)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    # Second GEMM: A already has one row per slot, hence top_k=1. The routing
    # weights are applied here (mul_routed_weight=True).
    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
                            topk_weights, topk_ids, sorted_token_ids,
                            expert_ids, num_tokens_post_padded, True, 1,
                            config)

    # Sum the contributions of each token's top-k experts.
    if inplace:
        result = torch.sum(intermediate_cache3, dim=1, out=hidden_states)
    else:
        result = torch.sum(intermediate_cache3, dim=1)

    if MEASURE_TIME:
        # Without this synchronize, perf_counter would mostly capture kernel
        # launch overhead rather than GPU execution time.
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        elapsed_time = (end_time - start_time) * 1_000  # Convert to milliseconds
        print(f"Batch size = {M}, elapsed_time = {elapsed_time} ms")
    return result
In my end-to-end tests on 8 H100 GPUs, the column-major MoE kernel showed no speedup. However, when I ran the script `kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py`, I did see a 2–3× speedup. I then compared the column-major MoE against vLLM's MoE on 2 H100 GPUs and measured about a 25% speedup, but that speedup was not observed on 8 H100 GPUs. Is this result reasonable? To reproduce:
Both benchmark runs (vLLM MoE and column-major MoE) gave roughly the same throughput.
The code for Column-Major_fused_moe.py: