Closed: carmocca closed this issue 2 months ago.
@carmocca This implementation is only intended to be used for processing one token at a time (or at most a few tokens).
For prefill, you don't want to use this implementation strategy. You instead want to implement it as a "fully dense" matrix multiply followed by gathering the appropriate tensors.
I think it should be something like this. Note that I haven't tested this for correctness yet! (although it does run).
```python
import torch
import torch.nn as nn
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True


class ConditionalLinear(nn.Module):
    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_experts, out_features, in_features))
        self.n_experts = num_experts

    def forward(self, x, expert_indices):
        if expert_indices.shape[0] <= 2:
            # Decode path: gather only the weights of the activated experts.
            # self.w[expert_indices] is [T, A, O, I]; flatten to [T*A, O, I].
            w_weights = self.w[expert_indices].view(-1, *self.w.shape[-2:])
            return torch.einsum("ti, toi -> to", x, w_weights)
        else:
            # Prefill path: run every expert densely, then select each token's
            # expert output with a one-hot matmul instead of a weight gather.
            dense_out = torch.einsum("ti, eoi -> teo", x, self.w)
            one_hot_indices = torch.nn.functional.one_hot(
                expert_indices.view(-1), num_classes=self.n_experts
            ).to(dtype=dense_out.dtype)
            return torch.einsum("teo, te -> to", dense_out, one_hot_indices)


class ConditionalFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = ConditionalLinear(n_expert, n_embd, intermediate_size)
        self.w2 = ConditionalLinear(n_expert, intermediate_size, n_embd)
        self.w3 = ConditionalLinear(n_expert, n_embd, intermediate_size)

    def forward(self, x, expert_indices):
        # Replicate each token once per activated expert: [T, D] -> [T*A, D].
        x = x.unsqueeze(1).expand(x.shape[0], expert_indices.shape[-1], x.shape[-1])
        x = x.reshape(-1, x.shape[-1])
        x1 = torch.nn.functional.silu(self.w1(x, expert_indices))
        x3 = self.w3(x, expert_indices)
        expert_outs = self.w2(x1 * x3, expert_indices)
        return expert_outs


class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(n_embd, n_expert, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = n_embd

    def forward(self, x):
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = torch.nn.functional.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, n_expert_per_token, dim=-1)  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum(
            "tai, ta -> ti",
            expert_outs.view(-1, n_expert_per_token, expert_outs.shape[-1]),
            expert_weights,
        )


def bench(f, name=None, iters=100, warmup=5):
    import time

    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time() - begin) * 1e6 / iters
    res = us_per_iter if name is None else f"{name}: {us_per_iter}us"
    print(res)
    return res


n_embd = 4096
n_expert = 8
intermediate_size = 14336
n_expert_per_token = 2

torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
B = 1
torch.set_default_dtype(torch.bfloat16)

for seq_len in [1, 4096]:
    x = torch.randn(B, seq_len, n_embd, device=device)
    with torch.no_grad():
        with device:
            m = MOEFeedForward()
        print("MODEL", torch.cuda.max_memory_allocated() / 1e9)
        torch.cuda.reset_peak_memory_stats()
        bench(lambda: m(x), name="eager")
        print(torch.cuda.max_memory_allocated() / 1e9)
        torch.cuda.reset_peak_memory_stats()
        cm = torch.compile(m)
        bench(lambda: cm(x), name="compiled")
        print(torch.cuda.max_memory_allocated() / 1e9)
```
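One quick way to sanity-check the untested dense path is to compare the two ConditionalLinear branches on the same small input. This is a minimal sketch; the dimensions are arbitrary, and float32 is used to avoid bf16 tolerance noise:

```python
# Sketch: verify the gather branch and the dense + one-hot branch agree.
torch.manual_seed(0)
lin = ConditionalLinear(num_experts=4, in_features=8, out_features=16).float()
nn.init.normal_(lin.w)
idx = torch.randint(0, 4, (2, 2))                # 2 tokens -> gather branch fires
xs = torch.randn(2 * 2, 8, dtype=torch.float32)  # flattened [T*A, in_features]
gathered = lin(xs, idx)
# Inline the dense branch to compare against the gathered result:
dense = torch.einsum("ti, eoi -> teo", xs, lin.w)
one_hot = torch.nn.functional.one_hot(idx.view(-1), num_classes=4).to(dense.dtype)
assert torch.allclose(gathered, torch.einsum("teo, te -> to", dense, one_hot), atol=1e-5)
```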
Thank you for clarifying!
This is so awesome, @Chillee!
QQ: how does this work if I am already using CUDA graphs? Does torch.compile with mode="reduce-overhead" work?
I noticed that if I run torch.compile with mode="reduce-overhead", it raises errors like:
RuntimeError: Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture. If you need this call to be captured, please file an issue. Current cudaStreamCaptureStatus: cudaStreamCaptureStatusActive
But running with the default mode seems to work.
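For reference, the failing configuration is presumably just the standard reduce-overhead invocation applied to the script above (a sketch, assuming the `m` and `x` defined there):

```python
# Assumption: standard torch.compile usage. "reduce-overhead" wraps the compiled
# region in CUDA graphs, which is where the seed-query error surfaces.
cm = torch.compile(m, mode="reduce-overhead")
out = cm(x)  # RuntimeError raised during CUDA graph capture

# The default mode does not capture CUDA graphs and runs fine:
cm_default = torch.compile(m)
out = cm_default(x)
```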
I'm trying to benchmark the ConditionalFeedForward implementation for Mixtral. The code below is taken from https://github.com/pytorch-labs/gpt-fast/blob/main/mixtral-moe/model.py#L187-L201
This fails with an attempt to allocate an extreme amount of memory.
This makes sense if you consider that the expert-weight indexing will allocate a new tensor:
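(The quoted snippet is elided in this copy.) For a sense of scale, here is a back-of-the-envelope sketch of what that gather allocates during a 4096-token prefill, assuming Mixtral's dimensions and bf16; the exact figure from the issue is not shown here:

```python
# Rough size of self.w1[expert_indices] at prefill: the gather materializes a
# [T, A, intermediate_size, n_embd] tensor, i.e. one expert-weight copy per
# (token, activated expert) pair. Assumed Mixtral dims, bfloat16.
T, A = 4096, 2                       # tokens, activated experts per token
intermediate_size, n_embd = 14336, 4096
bytes_per_elem = 2                   # bfloat16
size = T * A * intermediate_size * n_embd * bytes_per_elem
print(f"{size / 1e9:.0f} GB")        # ~962 GB, hence the allocation failure
```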
What am I missing? Or is this only intended to be used with tensor parallelism?
Environment
torch==2.3.0.dev20240301+cu121