huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Some Bugs in JetMoE #31791

Open Phoenix-Shen opened 2 weeks ago

Phoenix-Shen commented 2 weeks ago

System Info

transformers version: 4.43.0.dev0 (installed from source)

Who can help?

@ArthurZucker

Reproduction

Outline: A couple of bugs prevent JetMoE from outputting the gating (router) logits and from computing aux_loss.

1. Code: I want to output the gating (router) logits.

   ```python
   import os

   import torch
   from transformers import AutoModelForCausalLM, AutoTokenizer

   BASE_DIR = "model_ckpt"

   model_name = os.path.join(BASE_DIR, "jetmoe-8b")
   tokenizer = AutoTokenizer.from_pretrained(model_name)
   model = AutoModelForCausalLM.from_pretrained(
       model_name, torch_dtype="auto", device_map="auto"
   )

   # A dummy batch of 32 sequences of 12 token ids, asking for router logits.
   output = model.forward(
       torch.zeros(32, 12, device="cuda", dtype=torch.long),
       output_router_logits=True,
       return_dict=True,
   )
   ```


3. Running it raises the following error:

   ```
   Traceback (most recent call last):
     File "/home/ubuntu/ssk/test_jetmoe.py", line 18, in <module>
       output = model.forward(
     File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
       output = module._old_forward(*args, **kwargs)
     File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1365, in forward
       self.num_experts,
     File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
       raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   AttributeError: 'JetMoeForCausalLM' object has no attribute 'num_experts'
   ```

4. Analysis
After examining the code (https://github.com/huggingface/transformers/blob/main/src/transformers/models/jetmoe/modeling_jetmoe.py), I found several mistakes:
- `self.num_experts` and `self.num_experts_per_tok` are referenced but never defined in the `JetMoeForCausalLM` class.
- The code does not pass the `output_router_logits` argument to the forward function of `self.model` in the `JetMoeForCausalLM` class (see lines 1310 and 1341 of modeling_jetmoe.py).
- The `JetMoeForSequenceClassification` class never computes aux_loss, and it also fails to pass the `output_router_logits` argument to `self.model.forward`.
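The first point can be illustrated without loading the model. The sketch below uses hypothetical stand-in classes (`ToyConfig`, `BrokenHead`, and `FixedHead` are illustrative names, not transformers APIs, and the values 8 and 2 are arbitrary) to show why reading `self.num_experts` fails when `__init__` never assigns it, and how copying the values from the config, as the proposed fix does, resolves it. In the real class, `torch.nn.Module.__getattr__` raises the same kind of `AttributeError` seen in the traceback.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the two config fields named in the issue.
    num_local_experts: int = 8
    num_experts_per_tok: int = 2


class BrokenHead:
    """Mirrors the bug: a method reads attributes that __init__ never sets."""

    def __init__(self, config):
        self.config = config

    def router_settings(self):
        # Raises AttributeError, like JetMoeForCausalLM.forward in the traceback.
        return self.num_experts, self.num_experts_per_tok


class FixedHead(BrokenHead):
    """Applies the proposed fix: copy the values from the config in __init__."""

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok


config = ToyConfig()
try:
    BrokenHead(config).router_settings()
except AttributeError as err:
    print(f"broken: {err}")  # broken: 'BrokenHead' object has no attribute 'num_experts'
print("fixed:", FixedHead(config).router_settings())  # fixed: (8, 2)
```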

5. Quick fix for the `JetMoeForCausalLM` class
- Add `self.num_experts = config.num_local_experts` and `self.num_experts_per_tok = config.num_experts_per_tok` to the `__init__` method of `JetMoeForCausalLM`.
- Pass `output_router_logits` to `self.model.forward` (line 1331):
  ```python
  # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
  outputs = self.model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      position_ids=position_ids,
      past_key_values=past_key_values,
      inputs_embeds=inputs_embeds,
      use_cache=use_cache,
      output_attentions=output_attentions,
      output_hidden_states=output_hidden_states,
      return_dict=return_dict,
      cache_position=cache_position,
      output_router_logits=output_router_logits,  # Add this line.
  )
  ```
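Once `output_router_logits` is forwarded, the router logits can feed the auxiliary load-balancing loss that both `JetMoeForCausalLM` and `JetMoeForSequenceClassification` need; the traceback suggests this is where `self.num_experts` and `self.num_experts_per_tok` are consumed. As a rough illustration only, here is a minimal pure-Python sketch of a Switch-Transformer-style load-balancing loss. It is not the code in `modeling_jetmoe.py`, it ignores padding masks, and the function name `load_balancing_loss`, the toy logits, and the `lm_loss`/`aux_loss_coef` values are assumptions made for the example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one token's expert logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits, num_experts, top_k):
    """Simplified Switch-style load-balancing loss (illustrative sketch).

    router_logits: one list of num_experts logits per token, e.g. the
    per-layer router logits concatenated along the token axis.
    """
    num_tokens = len(router_logits)
    picked = [0] * num_experts       # tokens that routed to each expert (top-k)
    prob_sums = [0.0] * num_experts  # summed router probability per expert
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(num_experts), key=lambda i: probs[i], reverse=True)[:top_k]
        for i in top:
            picked[i] += 1
        for i, p in enumerate(probs):
            prob_sums[i] += p
    # num_experts * sum_i (fraction of tokens sent to i) * (mean probability of i)
    return num_experts * sum(
        (picked[i] / num_tokens) * (prob_sums[i] / num_tokens)
        for i in range(num_experts)
    )


# Toy router logits for 4 tokens and 4 experts (top-1 routing).
balanced = [[5.0, 0.0, 0.0, 0.0], [0.0, 5.0, 0.0, 0.0],
            [0.0, 0.0, 5.0, 0.0], [0.0, 0.0, 0.0, 5.0]]
collapsed = [[5.0, 0.0, 0.0, 0.0]] * 4  # every token prefers expert 0

print(load_balancing_loss(balanced, num_experts=4, top_k=1))   # approximately 1.0
print(load_balancing_loss(collapsed, num_experts=4, top_k=1))  # approximately 3.92

# Hypothetical combination with the task loss; in the model the coefficient
# would come from the config rather than being hard-coded.
lm_loss = 2.3
aux_loss_coef = 0.01
total_loss = lm_loss + aux_loss_coef * load_balancing_loss(collapsed, 4, 1)
```

A balanced router gives a loss near `top_k` (1.0 here), while sending every token to one expert pushes it toward `num_experts`; scaling by a small coefficient keeps it from dominating the task loss.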

Expected behavior

The solution has been described in the previous section.

LysandreJik commented 2 weeks ago

Thanks @Phoenix-Shen! Let me cc @yikangshen, who contributed the model.

yikangshen commented 1 week ago

Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?

Phoenix-Shen commented 1 week ago

> Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?

Ok, I've fixed all the bugs and am ready to submit a PR.

ArthurZucker commented 1 week ago

Thanks, reviewed!