xrsrke / pipegoose

Large scale 4D parallelism pre-training for 🤗 transformers in Mixture of Experts *(still work in progress)*

Add expert loss function #43

Closed · danielgrittner closed this 9 months ago

danielgrittner commented 9 months ago

The following is a sketch of how `ExpertLoss` could be used:

```python
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# ExpertLoss, ExpertParallel, and init_parallel_context are the
# proposed pipegoose APIs sketched in this issue.
# aux_weight and z_weight scale the router's auxiliary
# (load-balancing) loss and z-loss respectively.
loss_func = ExpertLoss(CrossEntropyLoss(), aux_weight=0.1, z_weight=0.1)

parallel_context = init_parallel_context(...)
model = ExpertParallel(
    model,
    NUM_EXPERTS,
    mapping=mapping,
    router=router,
    parallel_context=parallel_context,
    expert_context=loss_func.expert_context,  # PASS EXPERT CONTEXT HERE
).parallelize()
optim = Adam(model.parameters(), lr=1e-3)

outputs = model(**kwargs["input"])

loss = loss_func(outputs.logits, labels)

optim.zero_grad()
loss.backward()
optim.step()
```
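
For context, here is a minimal sketch of what `ExpertLoss` could compute internally, assuming the router appends its auxiliary (load-balancing) and z losses to the shared `expert_context` during the forward pass. The `ExpertContext` class and its `pop_all` method are hypothetical names used for this sketch, not an existing pipegoose API:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class ExpertContext:
    """Hypothetical accumulator: the router appends one aux/z loss per MoE layer."""
    aux_losses: list = field(default_factory=list)
    z_losses: list = field(default_factory=list)

    def pop_all(self):
        # Drain the accumulated losses so they don't leak into the next step.
        aux, z = self.aux_losses, self.z_losses
        self.aux_losses, self.z_losses = [], []
        return aux, z


class ExpertLoss:
    """Sketch: task loss plus weighted router losses drained from the context."""

    def __init__(self, loss_func, aux_weight=0.0, z_weight=0.0):
        self.loss_func = loss_func
        self.aux_weight = aux_weight
        self.z_weight = z_weight
        self.expert_context = ExpertContext()

    def __call__(self, logits, targets):
        loss = self.loss_func(logits, targets)
        aux_losses, z_losses = self.expert_context.pop_all()
        if aux_losses:
            loss = loss + self.aux_weight * torch.stack(aux_losses).sum()
        if z_losses:
            loss = loss + self.z_weight * torch.stack(z_losses).sum()
        return loss
```

Draining the context on each call keeps router losses from leaking across training steps, which is also why the model receives `loss_func.expert_context` at construction time instead of the router returning its losses directly.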

@xrsrke Let me know what you think about this design.