huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
132.04k stars 26.3k forks source link

dMoE support #27915

Open AlpinDale opened 9 months ago

AlpinDale commented 9 months ago

Feature request

MistralAI recently released their new model, a Mixture of Experts (MoE) based on megablocks, a library implementing dropless Mixture of Experts (dMoE).

Motivation

It's very likely that the future of open source LLMs will be MoEs. Having it in HF transformers would allow us to use the built-in trainer, as it's unwieldy to use Megatron-LM for the average user who's only ever done QLoRA.

Your contribution

No clue for now.

kevinhu commented 9 months ago

I've been using a quick drop-in replacement lifted from the dmoe.py implementation from megablocks.

Running torch.cat for the expert weights on each forward pass adds a ~5% overhead, since I didn't want to deal with managing the state dicts. Overall training is 2-3x faster.

class MixtralSparseMoeBlock(torch.nn.Module):
    """Mixtral sparse MoE block that evaluates all experts with block-sparse matmuls.

    Each token is routed to its ``top_k`` experts. Tokens are then permuted so
    that tokens assigned to the same expert are contiguous (each expert's group
    is padded up to a multiple of ``self.blocking`` rows), run through the
    gated MLP of every expert with two ``stk`` block-sparse products, and
    finally scattered back to their original order, weighted by the router.

    The routing/topology helpers are adapted from the megablocks ``dmoe.py``
    implementation (see the per-method source links).

    The expert parameters stay in ``self.experts`` (one module per expert), so
    the module's parameter layout is whatever ``MixtralBLockSparseTop2MLP``
    defines; the per-expert weights are concatenated on every forward pass.
    """

    def __init__(self, config: MixtralConfig):
        """Build the router, the experts, and the constants used for sorting.

        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``intermediate_size``, ``num_local_experts`` and
                ``num_experts_per_tok``.
        """
        super().__init__()

        self.config = config

        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
        self.experts = nn.ModuleList(
            [MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]
        )

        # Number of bits needed to represent an expert id (at least 1); handed
        # to ops.sort as its end bit when sorting expert assignments.
        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
        # Block size of the block-sparse matrices; token counts per expert are
        # rounded up to a multiple of this value.
        self.blocking = 128
        # Forwarded unchanged to ops.padded_scatter.
        # NOTE(review): -1 presumably disables quantization - confirm against megablocks.
        self.quantize_scatter_num_bits = -1
        # Number of block columns of the (tokens, ffn_dim * num_experts)
        # intermediate matrix; sizes the sort key used in sparse_transpose.
        max_column_index = (self.ffn_dim * self.num_experts) // self.blocking
        self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)

    # From https://github.com/stanford-futuredata/megablocks/blob/7c25169ce87c32c31e8845ef34785d3095b1a2cb/megablocks/layers/dmoe.py#L31
    def sparse_transpose(self, size, row_indices, column_indices):
        """Compute the transposed-layout metadata of a block-sparse matrix.

        Args:
            size: ``(rows, columns)`` of the dense shape; only ``size[1]`` is read.
            row_indices: Block row index of every nonzero block.
            column_indices: Block column index of every nonzero block.

        Returns:
            A tuple ``(column_indices_t, offsets_t, block_offsets_t)`` as
            expected by the ``stk.Matrix`` constructor.
        """
        block_columns = size[1] // self.blocking

        # Sort row indices by column indices to get the transposed matrix's
        # column indices.
        #
        # NOTE: Our sort operation uses the same width indices as the input values.
        # To avoid overflow when we have large activation matrices we cast to
        # 32-bit before sorting.
        _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)

        # There are a constant number of blocks in every row of the sparse matrix.
        # A block's offset is:
        #
        # row_index * blocks_per_row + column_index % blocks_per_row
        #
        # Once we have the block offsets ordered for transposition we can divide
        # by blocks_per_row to get the transposed column indices.
        column_indices_t = row_indices.gather(0, gather_indices.long())
        block_offsets_t = gather_indices.int()

        # Prefix-sum the per-column block counts and prepend a zero so that
        # offsets_t[i]:offsets_t[i + 1] delimits the blocks of column i.
        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
        nnz_per_column = ops.histogram(column_indices, block_columns)
        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
        offsets_t = torch.cat([zero, nnz_per_column])
        return column_indices_t, offsets_t, block_offsets_t

    # From https://github.com/stanford-futuredata/megablocks/blob/7c25169ce87c32c31e8845ef34785d3095b1a2cb/megablocks/layers/dmoe.py#L59
    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
        """Build the block-sparse topology of the intermediate activations.

        Args:
            x: Permuted and padded tokens, a 2-D tensor whose first dimension
                is a multiple of ``self.blocking``.
            padded_bins: Inclusive cumulative sum of the padded per-expert
                token counts (as produced by ``indices_and_padded_bins``).

        Returns:
            An ``stk.Matrix`` of shape ``(padded_tokens, ffn_dim * num_experts)``
            whose ``data`` is an uninitialized tensor on the ``meta`` device;
            only the index metadata is meaningful.
        """
        padded_tokens, _ = x.size()
        assert padded_tokens % self.blocking == 0, (
            f"padded token count {padded_tokens} is not a multiple of "
            f"the block size {self.blocking}"
        )
        assert self.ffn_dim % self.blocking == 0, (
            f"ffn_dim {self.ffn_dim} is not a multiple of "
            f"the block size {self.blocking}"
        )

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_dim // self.blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins, self.blocking, block_rows, blocks_per_row
        )

        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
            self.blocking,
            self.blocking,
            dtype=x.dtype,
            device="meta",
        )
        shape = (padded_tokens, self.ffn_dim * self.num_experts)
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
            shape, row_indices, column_indices
        )
        return stk.Matrix(
            shape,
            data,
            row_indices,
            column_indices,
            offsets,
            column_indices_t,
            offsets_t,
            block_offsets_t,
        )

    # From https://github.com/stanford-futuredata/megablocks/blob/7c25169ce87c32c31e8845ef34785d3095b1a2cb/megablocks/layers/dmoe.py#L103
    def indices_and_padded_bins(self, top_experts: torch.Tensor):
        """Derive the permutation and bin boundaries from the expert assignments.

        Args:
            top_experts: Flat tensor of selected expert ids, one entry per
                (token, top-k slot) pair.

        Returns:
            A tuple ``(indices, bin_ids, bins, padded_bins, tokens_per_expert)``:
            the sort permutation, the sorted expert ids, the inclusive
            cumulative token counts per expert, the same after rounding each
            count up to ``self.blocking``, and the raw per-expert counts.
        """
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        top_experts = top_experts.int()
        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        tokens_per_expert = ops.histogram(top_experts, self.num_experts)

        # Round the token counts up to the block size used in
        # the matrix multiplications. Calculate the starting
        # position of each bin.
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Calculate the bin bounds for the sorted tokens.
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        bins = promote_scalar(bins)
        return indices, bin_ids, bins, padded_bins, tokens_per_expert

    # From https://github.com/stanford-futuredata/megablocks/blob/7c25169ce87c32c31e8845ef34785d3095b1a2cb/megablocks/layers/dmoe.py#L126
    def sparse_forward(
        self,
        hidden_states: torch.Tensor,
        expert_weights: torch.Tensor,
        top_experts: torch.Tensor,
    ):
        """Apply the selected experts to every token using block-sparse matmuls.

        Args:
            hidden_states: Flattened tokens, ``(num_tokens, hidden_dim)`` as
                passed by ``forward``.
            expert_weights: Router weights of the selected experts,
                ``(num_tokens, top_k)``; cast to ``hidden_states.dtype`` here.
            top_experts: Ids of the selected experts, ``(num_tokens, top_k)``.

        Returns:
            The result of ``ops.padded_scatter``: the expert outputs permuted
            back to token order with the padding removed.
        """
        expert_weights = expert_weights.flatten().to(hidden_states.dtype)
        top_experts = top_experts.flatten()

        # The routing metadata is integer bookkeeping; no gradients needed.
        with torch.no_grad():
            (
                indices,
                bin_ids,
                bins,
                padded_bins,
                _,
            ) = self.indices_and_padded_bins(top_experts)

        # Permute tokens and pad to prepare expert computation
        # (top_k * sequence_length + padding, model_dim)
        # Route the tokens for MoE computation.
        permuted = ops.padded_gather(
            hidden_states, indices, bin_ids, bins, padded_bins, self.top_k
        )

        # Create the sparse matrix topology
        with torch.no_grad():
            topo = self.topology(permuted, padded_bins)

        # Concatenate the per-expert weights along the ffn dimension so that
        # one sparse product covers all experts. This is redone on every call
        # so the experts' own parameters (and state dict) stay untouched.
        # NOTE(review): assumes w1/w3 weights are (ffn_dim, hidden_dim) and w2
        # weights are (hidden_dim, ffn_dim) - confirm against the expert class.
        w1 = torch.cat([expert.w1.weight.T for expert in self.experts], dim=1)
        w2 = torch.cat([expert.w2.weight for expert in self.experts], dim=1).T
        w3 = torch.cat([expert.w3.weight.T for expert in self.experts], dim=1)

        # Perform the expert computation: silu(x @ w1) * (x @ w3), evaluated
        # only on the blocks present in the topology, then projected by w2.
        gated = F.silu(stk.ops.sdd(permuted, w1, topo).data)
        gated = gated * stk.ops.sdd(permuted, w3, topo).data
        intermediate = stk.Matrix(  # type: ignore
            topo.size(),
            gated,
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )
        expert_out = stk.ops.dsd(intermediate, w2)

        # Permute back and remove padding
        # (top_k * sequence_length, model_dim)
        return ops.padded_scatter(
            expert_out,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
            self.quantize_scatter_num_bits,
        )

    def forward(self, hidden_states: torch.Tensor):
        """Route every token to its top-k experts and combine their outputs.

        Args:
            hidden_states: Input activations whose last dimension is the
                hidden size; all leading dimensions are flattened into a
                single token dimension.

        Returns:
            A tuple ``(output, router_logits)`` where ``output`` has the same
            shape as the input and ``router_logits`` is the raw gate output of
            shape ``(num_tokens, num_experts)``.
        """
        orig_shape = hidden_states.shape
        hidden_dim = orig_shape[-1]

        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(hidden_states)

        # Compute the routing probabilities in float32 so that the softmax and
        # the renormalization below do not lose precision in half-precision
        # training; sparse_forward casts the weights back to the activation dtype.
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, expert_indices = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        # Renormalize so the selected experts' weights sum to one per token.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        output = self.sparse_forward(hidden_states, routing_weights, expert_indices)

        return output.view(*orig_shape), router_logits
DayOfThePenguin commented 8 months ago

@kevinhu what do you think would be the best way to modify the state dict to avoid the .cats? One option: merge the individual w1, w2, and w3 tensors that are currently stored in a list of MixtralBLockSparseTop2MLP modules into w1, w3 = nn.Linear(self.hidden_dim, self.ffn_dim * self.num_experts) and w2 = nn.Linear(self.ffn_dim * self.num_experts, self.hidden_dim) under the MixtralSparseMoeBlock, and then use .views of them for each expert? You're effectively just using the MixtralBLockSparseTop2MLP class as a dataclass for storing the expert weights and not actually using its forward() method.