arcee-ai / DAM

Seamless Switch Between On-the-Fly and Pre-Computed Logits #19

Closed shamanez closed 1 month ago

shamanez commented 1 month ago

Seamless Switch Between On-the-Fly and Pre-Computed Logits

Description:

This PR adds a significant enhancement to the DAMTrainer class: the ability to switch seamlessly between generating logits on the fly and consuming pre-computed logits. This flexibility is particularly useful for users with GPUs, who can compute the individual models' logits during training instead of preparing them ahead of time, speeding up operations and making testing experiments more efficient to run.

Key Changes:

  1. New Parameters (see the first sketch after this list):

    • generate_logits_on_fly: A boolean parameter to control whether logits should be generated on-the-fly or pre-computed.
    • use_all_logits: A boolean parameter to indicate if all logits should be used. This is only applicable when generate_logits_on_fly is True.
  2. Assertions:

    • Added an assertion to ensure that use_all_logits cannot be True if generate_logits_on_fly is False.
  3. Logits Computation (see the second sketch after this list):

    • When generate_logits_on_fly is True, logits for each individual model are computed dynamically during training.
    • When generate_logits_on_fly is False, pre-computed logits are used, and the top-K logits are gathered using the provided indices.
  4. Efficiency Improvements:

    • By generating logits on the fly, users with GPUs can skip the separate logit pre-computation step and let their hardware do that work during training, which is faster overall.
    • This flexibility also makes it easier to run a variety of testing experiments efficiently.
  5. Code Updates:

    • Updated the compute_loss method to handle both on-the-fly and pre-computed logits.
    • Modified the compute_individual_logit_losses method to accommodate the new parameters and logic.
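
Below is a minimal sketch of how the two new flags and the assertion from points 1 and 2 could be wired into the trainer's constructor. It is illustrative only: the real DAMTrainer accepts many more arguments (see the usage example further down), and the pass-through trainer_kwargs stand-in is an assumption rather than the repository's actual signature.

class DAMTrainer:  # simplified stand-in for the real trainer class
    def __init__(self, generate_logits_on_fly=False, use_all_logits=False, **trainer_kwargs):
        # use_all_logits is only meaningful when logits are generated on the fly,
        # because the pre-computed path stores only the top-K logits per token.
        assert not (use_all_logits and not generate_logits_on_fly), \
            "use_all_logits=True requires generate_logits_on_fly=True"
        self.generate_logits_on_fly = generate_logits_on_fly
        self.use_all_logits = use_all_logits
        self.trainer_kwargs = trainer_kwargs  # everything else: model, data, loss settings, ...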

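And here is a rough sketch of the logits branching described in points 3 and 5, i.e. how a compute_loss-style function can switch between the two paths. The batch keys (topk_logits, topk_indices), the k value, the assumption that each individual model is a Hugging Face-style causal LM exposing .logits, and the KL formulation are all illustrative; only the on-the-fly vs. pre-computed branching mirrors this PR.

import torch
import torch.nn.functional as F

def individual_logit_loss(merged_logits, batch, individual_model,
                          generate_logits_on_fly, use_all_logits, temperature=1.0):
    if generate_logits_on_fly:
        # Compute the individual model's logits dynamically during training.
        with torch.no_grad():
            teacher_logits = individual_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
            ).logits                                         # (B, T, vocab)
        if use_all_logits:
            student_logits = merged_logits                   # compare over the full vocabulary
        else:
            # Keep only the teacher's top-K entries and align the merged
            # model's logits with the same vocabulary positions.
            teacher_logits, topk_idx = teacher_logits.topk(k=50, dim=-1)
            student_logits = merged_logits.gather(-1, topk_idx)
    else:
        # Pre-computed path: the dataset already stores the top-K logits and
        # the vocabulary indices they belong to, so gather the merged model's
        # logits at those indices.
        teacher_logits = batch["topk_logits"]                # (B, T, K)
        student_logits = merged_logits.gather(-1, batch["topk_indices"])
    # KL divergence between temperature-scaled distributions.
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
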
Benefits:

The trainer becomes more versatile: users with GPUs can generate logits on the fly for faster runs, while pre-computed logits remain supported for setups where preparing them ahead of time is preferable.

Usage: To use the new functionality, set the generate_logits_on_fly and use_all_logits parameters when initializing the DAMTrainer:

trainer = DAMTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    lambda_coef=lambda_coef,
    lambda_coef_l1=lambda_coef_l1,
    lambda_coef_l2=lambda_coef_l2,
    temperature=temperature,
    use_kl=use_kl,
    use_mse=use_mse,
    use_entropy=use_entropy,
    base_model_path=base_model_name,
    generate_logits_on_fly=True,  # Enable on-the-fly logits generation
    use_all_logits=True,  # Use all logits when generating on-the-fly
)
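
Conversely, if the per-model logits have already been exported, the same call works with the two flags flipped (all other arguments unchanged from the example above):

trainer = DAMTrainer(
    ...,  # same model/data/loss arguments as in the example above
    generate_logits_on_fly=False,  # consume pre-computed top-K logits from the dataset
    use_all_logits=False,          # must be False when logits are pre-computed
)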

This PR enhances the DAMTrainer class, making it more versatile and efficient for various training and testing scenarios.