Phoenix-Shen opened this issue 2 weeks ago.
Thanks @Phoenix-Shen! Let me cc @yikangshen, who contributed the model.
Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?
Ok, I've fixed all the bugs and am ready to submit a PR.
Thanks, reviewed!
System Info
transformers version: 4.43.0.dev0 (installed from source)
Who can help?
@ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Outline: A couple of bugs prevent JetMoE from outputting the gating (router) logits and, as a result, from computing aux_loss.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_DIR = "model_ckpt"  # local directory that holds the jetmoe-8b checkpoint
model_name = os.path.join(BASE_DIR, "jetmoe-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
output = model(input_ids=torch.zeros(32, 12, device="cuda", dtype=torch.long), output_router_logits=True, return_dict=True)  # ask for gating logits, needed for aux_loss
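For context, the aux_loss in the outline is computed from the router (gating) logits that output_router_logits=True is supposed to return. A common formulation is the Switch Transformer style load-balancing loss, num_experts * sum_i f_i * P_i, where f_i is the fraction of tokens that pick expert i among their top-k choices and P_i is the mean router probability for expert i. The pure-Python sketch below illustrates that formula only. It assumes a hypothetical [num_tokens][num_experts] list layout, and the function names are illustrative. It is not JetMoE's actual implementation in transformers, which may differ in details such as masking or per-layer aggregation.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one token's router logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits: List[List[float]], top_k: int) -> float:
    """Hypothetical sketch of a Switch-style aux loss: E * sum_i f_i * P_i.

    router_logits: per-token lists of per-expert logits, i.e. an assumed
    [num_tokens][num_experts] layout (not JetMoE's real tensor shapes).
    """
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]

    # P_i: mean router probability assigned to expert i across all tokens.
    mean_prob = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]

    # f_i: fraction of tokens that route to expert i within their top-k choices.
    counts = [0] * num_experts
    for p in probs:
        chosen = sorted(range(num_experts), key=lambda i: p[i], reverse=True)[:top_k]
        for i in chosen:
            counts[i] += 1
    frac = [c / num_tokens for c in counts]

    # Scaling by num_experts makes perfectly balanced routing score exactly top_k.
    return num_experts * sum(f * pr for f, pr in zip(frac, mean_prob))


if __name__ == "__main__":
    balanced = [[10.0, 0.0, 0.0, 0.0], [0.0, 10.0, 0.0, 0.0],
                [0.0, 0.0, 10.0, 0.0], [0.0, 0.0, 0.0, 10.0]]
    collapsed = [[10.0, 0.0, 0.0, 0.0]] * 4
    print(load_balancing_loss(balanced, top_k=1))
    print(load_balancing_loss(collapsed, top_k=1))
```

With this formulation, evenly spread routing yields a loss equal to top_k, while sending every token to the same expert pushes the loss toward num_experts. That gap is the signal the auxiliary loss uses to discourage expert collapse, and it is unavailable when the router logits are not returned.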
Expected behavior
The solution has been described in the previous section.