databricks / megablocks

save loading_balancing_loss properly #82

Closed gouchangjiang closed 5 months ago

gouchangjiang commented 6 months ago

Hi, in the forward function of ParallelMLP, should we save the load_balancing_loss directly, or a tuple of tokens_per_expert and scores? In other words, should line 428, save_load_balancing_loss((tokens_per_expert, scores)), be replaced by save_load_balancing_loss(self.load_balancing_loss(tokens_per_expert, scores))?

gouchangjiang commented 6 months ago

By the way, should we use a load balancing loss when using megablocks?

tgale96 commented 6 months ago

Hi! We save the data required to compute the load balancing loss, rather than the load balancing loss itself, so that we can compute the LBL for all layers at once using batched_load_balancing_loss.

And yes, we highly recommend using load balancing losses when training MoEs!
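For readers following along, the save-then-batch pattern described in this reply can be sketched roughly as below. This is a minimal pure-Python illustration, not megablocks code: the function names mirror those mentioned in the thread, but their bodies and signatures here are hypothetical stand-ins, and the loss formula is an assumed Switch-Transformer-style auxiliary loss (number of experts times the sum over experts of routed-token fraction times mean router probability).

```python
# Hypothetical stand-ins for illustration only; not the megablocks implementation.
_SAVED = []  # one (tokens_per_expert, scores) entry per MoE layer


def save_load_balancing_loss(inputs):
    """Store the raw inputs for one layer instead of a finished loss value."""
    _SAVED.append(inputs)


def clear_load_balancing_loss():
    """Drop saved inputs, e.g. at the start of each training step."""
    _SAVED.clear()


def batched_load_balancing_loss(num_experts, top_k=1):
    """Compute an assumed Switch-style loss over every saved layer, then average."""
    if not _SAVED:
        raise ValueError("no load balancing inputs were saved")
    total = 0.0
    for tokens_per_expert, scores in _SAVED:
        num_tokens = len(scores)
        # Fraction of routing assignments that went to each expert.
        frac = [t / (num_tokens * top_k) for t in tokens_per_expert]
        # Mean router probability assigned to each expert.
        mean_prob = [
            sum(row[e] for row in scores) / num_tokens for e in range(num_experts)
        ]
        total += num_experts * sum(f * p for f, p in zip(frac, mean_prob))
    return total / len(_SAVED)


# Usage: two layers, 4 tokens, 2 experts. Each "forward" only saves its inputs.
clear_load_balancing_loss()
save_load_balancing_loss(([2, 2], [[0.5, 0.5]] * 4))  # balanced layer
save_load_balancing_loss(([4, 0], [[0.9, 0.1]] * 4))  # imbalanced layer
loss = batched_load_balancing_loss(num_experts=2)
```

In this toy setup the balanced layer contributes less than the imbalanced one, and the training loop makes a single call after all layers have run rather than one loss computation per layer.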