databricks / megablocks


Skip updating load balancing loss on eval #69

Closed: sedrick-keh-tri closed this 6 months ago

sedrick-keh-tri commented 6 months ago

Currently the model updates the load balancing loss during every forward pass.

So in a pipeline that runs validation every X batches, the validation forward passes also update the load balancing loss, which leads to an assertion error like: Expected 6 token_per_experts but found 54

This PR fixes that: the load balancing loss is now updated only during forward passes in model.train() mode, not in model.eval() mode.

tgale96 commented 6 months ago

Thanks for the contribution!! @mvpatel2000 would you mind taking a look?

tgale96 commented 6 months ago

Ah, this looks fine to me :) Thanks again!