pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

On the memory usage of `ConditionalFeedForward` #149

Closed carmocca closed 2 months ago

carmocca commented 3 months ago

I'm trying to benchmark the ConditionalFeedForward implementation for Mixtral.

The code below is taken from https://github.com/pytorch-labs/gpt-fast/blob/main/mixtral-moe/model.py#L187-L201

import time

import torch
import torch.nn as nn

class ConditionalFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(n_expert, intermediate_size, n_embd))
        self.w2 = nn.Parameter(torch.empty(n_expert, n_embd, intermediate_size))
        self.w3 = nn.Parameter(torch.empty(n_expert, intermediate_size, n_embd))

    def forward(self, x, expert_indices):
        w1_weights = self.w1[expert_indices] # [T, A, D, D]
        w3_weights = self.w3[expert_indices] # [T, A, D, D]
        w2_weights = self.w2[expert_indices]  # [T, A, D, D]
        x1 = torch.nn.functional.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
        return expert_outs

class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(n_embd, n_expert, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = n_embd

    def forward(self, x):
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x) # [T, E]
        expert_weights = torch.nn.functional.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, n_expert_per_token, dim=-1) # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)

def bench(f, name=None, iters=100, warmup=5):
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time() - begin) * 1e6 / iters
    res = us_per_iter if name is None else f"{name}: {us_per_iter}us"
    print(res)
    return res

n_embd = 4096
n_expert = 8
intermediate_size = 14336
n_expert_per_token = 2
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
B = 1
torch.set_default_dtype(torch.bfloat16)
x = torch.randn(B, 4096, n_embd, device=device)

with device:
    m = MOEFeedForward()
print("MODEL", torch.cuda.max_memory_allocated() / 1e9)
torch.cuda.reset_peak_memory_stats()

bench(lambda: m(x), name="eager")
print(torch.cuda.max_memory_allocated() / 1e9)
torch.cuda.reset_peak_memory_stats()

cm = torch.compile(m)
bench(lambda: cm(x), name="compiled")
print(torch.cuda.max_memory_allocated() / 1e9)

This fails with an extremely large allocation attempt:

  File "/home/carlos/bench.py", line 15, in forward
    w1_weights = self.w1[expert_indices] # [T, A, D, D]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 896.00 GiB. GPU 

This makes sense once you consider that indexing with a tensor of indices materializes a new tensor:

>>> w = torch.randn(8, 16, 4)
>>> indices = torch.randint(0, 8, (32, 2))
>>> w.storage().data_ptr()
94871054609536
>>> w[indices].storage().data_ptr()
94871054583232
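
As a back-of-the-envelope check (my arithmetic, assuming the sizes and the bf16 default set in the script above), the gathered w1 alone for a 4096-token prefill is a [T, A, intermediate_size, n_embd] tensor, which lines up with the 896 GiB in the error:

T, A = 4096, 2                           # tokens in the prefill, activated experts per token
intermediate_size, n_embd = 14336, 4096
bytes_per_elem = 2                       # bfloat16
print(T * A * intermediate_size * n_embd * bytes_per_elem / 2**30)  # 896.0, matching the OOM message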

What am I missing? Or is this only intended to be used with tensor parallelism?

Environment

torch==2.3.0.dev20240301+cu121

Chillee commented 2 months ago

@carmocca This implementation is only intended to be used for processing one token at a time (or at least, few tokens).
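
To put rough numbers on that (an estimate using the sizes from the script above), the same gather is tiny for a single decode token, so the per-token strategy is fine there:

T, A = 1, 2                              # one decode token, two activated experts
print(T * A * 14336 * 4096 * 2 / 2**30)  # ~0.22 GiB for the w1 gather at T = 1, vs 896 GiB at T = 4096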

For prefill, you don't want to use this implementation strategy. You instead want to implement it as a "fully dense" matrix multiply followed by gathering the appropriate tensors.

Chillee commented 2 months ago

I think it should be something like this. Note that I haven't tested this for correctness yet! (although it does run).

import torch
import torch.nn as nn
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True

class ConditionalLinear(nn.Module):
    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_experts, out_features, in_features))
        self.n_experts = num_experts

    def forward(self, x, expert_indices):
        if expert_indices.shape[0] <= 2:
            w_weights = self.w[expert_indices].view(-1, *self.w.shape[-2:]) # [T, A, O, I]
            return torch.einsum("ti, toi -> to", x, w_weights)
        else:
            dense_out = torch.einsum("ti, eoi -> teo", x, self.w)
            one_hot_indices = torch.nn.functional.one_hot(expert_indices.view(-1), num_classes=self.n_experts).to(dtype=dense_out.dtype)
            return torch.einsum("teo, te -> to", dense_out, one_hot_indices)

class ConditionalFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = ConditionalLinear(n_expert, n_embd, intermediate_size)
        self.w2 = ConditionalLinear(n_expert, intermediate_size, n_embd)
        self.w3 = ConditionalLinear(n_expert, n_embd, intermediate_size)

    def forward(self, x, expert_indices):
        x = x.unsqueeze(1).expand(x.shape[0], expert_indices.shape[-1], x.shape[-1])
        x = x.reshape(-1, x.shape[-1])
        x1 = torch.nn.functional.silu(self.w1(x, expert_indices))
        x3 = self.w3(x, expert_indices)
        expert_outs = self.w2((x1 * x3), expert_indices)
        return expert_outs

class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(n_embd, n_expert, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = n_embd

    def forward(self, x):
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x) # [T, E]
        expert_weights = torch.nn.functional.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, n_expert_per_token, dim=-1) # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum('tai,ta -> ti', expert_outs.view(-1, n_expert_per_token, expert_outs.shape[-1]), expert_weights)

def bench(f, name=None, iters=100, warmup=5):
    import time
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time() - begin) * 1e6 / iters
    res = us_per_iter if name is None else f"{name}: {us_per_iter}us"
    print(res)
    return res

n_embd = 4096
n_expert = 8
intermediate_size = 14336
n_expert_per_token = 2
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
B = 1
torch.set_default_dtype(torch.bfloat16)
for seq_len in [1, 4096]:
    x = torch.randn(B, seq_len, n_embd, device=device)

    with torch.no_grad():
        with device:
            m = MOEFeedForward()
        print("MODEL", torch.cuda.max_memory_allocated() / 1e9)
        torch.cuda.reset_peak_memory_stats()

        bench(lambda: m(x), name="eager")
        print(torch.cuda.max_memory_allocated() / 1e9)
        torch.cuda.reset_peak_memory_stats()

        cm = torch.compile(m)
        bench(lambda: cm(x), name="compiled")
        print(torch.cuda.max_memory_allocated() / 1e9)
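
As a rough sanity check on the dense path (an estimate, not a measured number): the dense_out intermediate in ConditionalLinear is only [T, E, out_features], so even at seq_len = 4096 the largest one is under 1 GiB in bf16, instead of the 896 GiB gather from the original snippet.

T, E, out_features = 4096, 8, 14336      # tokens, experts, intermediate_size (largest case, w1/w3)
print(T * E * out_features * 2 / 2**30)  # ~0.88 GiB per dense_out tensor in bf16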
carmocca commented 2 months ago

Thank you for clarifying!

rickyyx commented 2 months ago

This is so awesome @Chillee !

QQ: how does this work if I am already using CUDA graphs? Does torch.compile with mode="reduce-overhead" work?

I noticed that if I do torch.compile with reduce-overhead, it raises errors like:

 RuntimeError: Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture. If you need this call to be captured, please file an issue. Current cudaStreamCaptureStatus: cudaStreamCaptureStatusActive

But running with the default mode seems to work.
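
For reference, a minimal sketch of the two modes being compared (reusing m, x, and bench from the snippets above; mode="reduce-overhead" is the torch.compile flag that enables CUDA graphs):

cm_default = torch.compile(m)                             # default mode: runs fine
cm_cudagraphs = torch.compile(m, mode="reduce-overhead")  # CUDA-graph path: hits the error above
bench(lambda: cm_default(x), name="compiled-default")
bench(lambda: cm_cudagraphs(x), name="compiled-reduce-overhead")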