huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

save_pretrained is changing the name of a module when saving #33680

Open ZhiyuanChen opened 2 weeks ago

ZhiyuanChen commented 2 weeks ago

System Info

Who can help?

No response

Information

Tasks

Reproduction


from __future__ import annotations

from typing import Tuple

import torch
from torch import Tensor, nn

from transformers.activations import ACT2FN


class XxxSparseLayer(nn.Module):
    def __init__(self, config: XxxConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.expert_top_k
        self.attention = XxxAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.router = nn.Linear(config.hidden_size, self.num_experts)
        self.experts = nn.ModuleList([XxxFeedForward(config) for _ in range(self.num_experts)])
        self._register_load_state_dict_pre_hook(self.copy_ffn_params)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: torch.FloatTensor | None = None,
        head_mask: torch.FloatTensor | None = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, ...]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        layer_output = self.layer_norm(attention_output)

        # Route on the first token's hidden state: keep the top-k experts and
        # renormalize their weights so they sum to 1.
        router_logits = self.router(layer_output[:, 0, :])
        router_probs = router_logits.softmax(dim=-1)
        router_weights, router_idx = router_probs.topk(self.top_k, dim=-1)
        router_weights /= router_weights.sum(dim=-1, keepdim=True)

        expert_outputs = torch.stack([self.experts[i](layer_output) for i in range(self.num_experts)], dim=1)
        solicited_outputs = expert_outputs[torch.arange(router_idx.size(0)).unsqueeze(1), router_idx]
        weighted_outputs = (solicited_outputs * router_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)

        layer_output = weighted_outputs + layer_output
        return (layer_output,) + self_attention_outputs[1:]

    def copy_ffn_params(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Load-time pre-hook: the checkpoint stores a single "ffn." block, which is
        # copied to every "experts.{i}." entry; the shared layer_norm is filtered out
        # because it lives on the layer itself.
        ffn_prefix = prefix + "ffn."
        ffn_states = {k: v for k, v in state_dict.items() if k.startswith(ffn_prefix)}
        ffn_states = {k: v for k, v in ffn_states.items() if "layer_norm" not in k}
        for k, v in ffn_states.items():
            for i in range(self.num_experts):
                state_dict[k.replace("ffn.", f"experts.{i}.")] = v.clone()
            del state_dict[k]

class XxxFeedForward(nn.Module):
    def __init__(self, config: XxxConfig):
        super().__init__()
        self.in_proj = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.act = ACT2FN[config.hidden_act]
        else:
            self.act = config.hidden_act
        self.out_proj = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.in_proj(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
>>> model.save_pretrained('xxx')
>>> model.state_dict() == torch.load('xxx/pytorch_model.bin')
False

Expected behavior

I have the above layer definition.

Since it's an MoE module, all experts share one layer_norm; the FFN's layer norm lives in the Layer, not in the FFN.

But when using save_pretrained, transformers automatically moves the weights to ffn, which causes the subsequent load to fail.
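
For reference, a quick way to see which parameters get renamed is to diff the key sets of the live model and the saved file (a minimal sketch; 'xxx' is the placeholder directory from the snippet above and model is the instance being saved):

import torch

saved = torch.load("xxx/pytorch_model.bin")
model_keys = set(model.state_dict())
saved_keys = set(saved)

# Any key that save_pretrained renamed or dropped shows up in one of these two sets.
print("in model but not in file:", sorted(model_keys - saved_keys))
print("in file but not in model:", sorted(saved_keys - model_keys))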

LysandreJik commented 2 weeks ago

cc @ArthurZucker on MoEs

ArthurZucker commented 1 week ago

Hey! Sorry but without the config class I can't help 😓 would you mind providing the full repro?

ZhiyuanChen commented 1 week ago

> Hey! Sorry but without the config class I can't help 😓 would you mind providing the full repro?

It should be reproducible with any config class, like ESM. I'm not allowed to share anything at this stage as it's under review.

ArthurZucker commented 1 week ago

Ah, sorry about that 😢 In general, disabling safe_serialization (passing safe_serialization=False) should prevent the weight removal. You should also fill in _tied_weights_keys, since these tensors basically share the same memory!
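
A minimal sketch of both suggestions, assuming a hypothetical XxxModel/XxxPreTrainedModel pair wrapping the layers above; safe_serialization is a real save_pretrained argument and _tied_weights_keys is the PreTrainedModel class attribute referred to here, but the entries below are placeholders:

class XxxModel(XxxPreTrainedModel):  # hypothetical model class for the layers above
    # Declare the duplicated expert weights as intentionally tied so that
    # save_pretrained does not deduplicate or rename them (adjust the entries
    # to the real key names/patterns of the model).
    _tied_weights_keys = ["experts"]

# Falling back to torch.save (pytorch_model.bin) instead of safetensors
# sidesteps the shared-tensor handling performed during safe serialization.
model.save_pretrained("xxx", safe_serialization=False)

Whether the entries need to be exact parameter names or regex-style patterns depends on the installed transformers version, so it is worth checking against modeling_utils.py before relying on this.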