Currently the model updates the load balancing loss during every forward pass. As a result, a pipeline that runs validation every X batches trips an assertion error such as `Expected 6 token_per_experts but found 54`.
This PR fixes that by updating the load balancing loss only during forward passes in `model.train()` mode, not in `model.eval()` mode.
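A minimal sketch of the fix pattern, using a toy router (all names here are illustrative, not the repository's actual code): the auxiliary loss is accumulated only when `self.training` is set, so `model.eval()` forward passes during validation leave the running total untouched.

```python
import torch
import torch.nn as nn


class Router(nn.Module):
    """Toy MoE router illustrating the train-only aux-loss update."""

    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        self.load_balancing_loss = 0.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        probs = self.gate(x).softmax(dim=-1)
        # Only accumulate the load balancing loss while training.
        # Validation forwards (model.eval()) must not touch the running
        # total, or the expected token count drifts out of sync.
        if self.training:
            mean_probs = probs.mean(dim=0)
            self.load_balancing_loss = self.load_balancing_loss + (
                probs.size(-1) * (mean_probs * mean_probs).sum()
            )
        return probs


router = Router(8, 4)
x = torch.randn(6, 8)

router.train()
router(x)   # accumulates the aux loss
router.eval()
router(x)   # leaves it unchanged
```

Gating on `self.training` (which `model.train()`/`model.eval()` toggle) keeps the accumulated loss consistent with the number of training tokens actually seen, regardless of how often validation runs.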