Enhance compute_loss Method for Dynamic Logit Generation
This PR enhances the compute_loss method in the DAMTrainer class to support dynamic generation of logits on the fly. The key changes include:
Dynamic Logit Generation: Introduced a new parameter generate_logits_on_fly to control whether logits should be generated dynamically during loss computation.
Conditional Logic: Added conditional logic to handle both scenarios:
If generate_logits_on_fly is True, the method generates logits from each of the merged_model.num_models models and computes the individual logit losses.
If generate_logits_on_fly is False, it uses the precomputed topk_logits and topk_indices from the inputs.
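The two branches above can be sketched as follows. This is a minimal, framework-free illustration and not the PR's actual code: the helper names (top_k, collect_reference_logits), the use of plain callables as models, the list-per-model layout of inputs["topk_logits"] and inputs["topk_indices"], and the choice to reduce on-the-fly logits to top-k are all assumptions made for the example.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-in: a "model" is any callable mapping token ids to logits.
Model = Callable[[Sequence[int]], List[float]]


def top_k(logits: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest logits and their indices, largest first."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return [logits[i] for i in order], order


def collect_reference_logits(
    inputs: Dict[str, list],
    models: Sequence[Model],
    generate_logits_on_fly: bool,
    k: int = 2,
) -> List[Tuple[List[float], List[int]]]:
    """Return one (topk_logits, topk_indices) pair per individual model."""
    if generate_logits_on_fly:
        # Run every model on the batch now and keep only its top-k logits.
        return [top_k(model(inputs["input_ids"]), k) for model in models]
    # Otherwise reuse what the dataset already stores (assumed layout:
    # one list entry per model under each key).
    return [
        (inputs["topk_logits"][i], inputs["topk_indices"][i])
        for i in range(len(models))
    ]


# Two toy models over a three-token vocabulary.
toy_models = [lambda ids: [0.1, 2.0, -1.0], lambda ids: [3.0, 0.0, 1.0]]
on_the_fly = collect_reference_logits({"input_ids": [1, 2]}, toy_models, True)
```

Under these assumptions, both branches return the same structure, so the per-model logit loss that follows does not need to know where the logits came from.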
These changes let compute_loss either generate logits during training or reuse the precomputed top-k logits, so the trainer can be configured for whichever approach suits a given training run.