Hi, great benchmark! I could not find the training configuration of the Flan models in your paper (for example, the learning rate and number of epochs). Would it be possible to guide me through those hyperparameters? Also, did you apply any regularizer while fine-tuning? Thanks!
The fine-tuning script has not yet been integrated into FusionBench. A PyTorch Lightning module for fine-tuning is shown below:
import logging
from pathlib import Path

import peft
import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

log = logging.getLogger(__name__)


class Seq2SeqLMModule(pl.LightningModule):
    def __init__(
        self,
        model: AutoModelForSeq2SeqLM | peft.PeftModel,
        tokenizer: AutoTokenizer,
        optim_cfg: DictConfig,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.optim_cfg = optim_cfg

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Returns:
            Dict: Dictionary containing the optimizer and learning rate scheduler.
        """
        optim = {}
        if "optimizer" in self.optim_cfg:
            optim["optimizer"]: torch.optim.Optimizer = instantiate(
                self.optim_cfg["optimizer"],
                params=self.parameters(),
            )
        if "lr_scheduler" in self.optim_cfg:
            optim["lr_scheduler"]: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self.optim_cfg["lr_scheduler"],
                optimizer=optim["optimizer"],
            )
        if self.trainer.is_global_zero:
            log.info(f"{'configure_optimizers':=^50}")
            log.info(optim)
        return optim

    def training_step(self, batch, batch_idx: int):
        outputs = self.forward(**batch)
        loss = outputs.loss
        self.log("train/loss", loss)
        return loss

    def save_trainable_parameters(self):
        if self.logger.log_dir is not None:
            # save only the parameters that require gradients
            # (for LoRA fine-tuning this is just the adapter weights)
            ckpt_path = (
                Path(self.trainer.log_dir)
                / "checkpoints"
                / f"epoch={self.current_epoch}_step={self.global_step}.pth"
            )
            if not ckpt_path.parent.exists():
                ckpt_path.parent.mkdir(parents=True, exist_ok=True)
            state_dict = dict(
                (k, p) for k, p in self.model.named_parameters() if p.requires_grad
            )
            torch.save(state_dict, ckpt_path)

    def on_train_epoch_end(self) -> None:
        self.save_trainable_parameters()
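For illustration, here is a minimal sketch of how a model can be prepared and handed to this module. The checkpoint name and the LoRA settings (rank, alpha, target modules) are placeholders, not necessarily what was used for the released models:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/flan-t5-base"  # placeholder: pick the Flan checkpoint you need
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# For LoRA fine-tuning, wrap the base model so that only the adapter weights keep
# requires_grad=True; save_trainable_parameters() above then stores only those.
lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,                       # illustrative rank
    lora_alpha=32,              # illustrative scaling
    target_modules=["q", "v"],  # common choice for the T5 attention projections
)
model = get_peft_model(model, lora_cfg)

For full fine-tuning, simply skip the get_peft_model step and pass the base model directly.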
We provide some fine-tuned Flan models on HuggingFace. They were fine-tuned with the Adam optimizer, using learning rates of 1e-5 and 2e-5 for full fine-tuning and 3e-5 and 4e-5 for LoRA fine-tuning (selecting the model with the better validation performance in each case), with weight decay set to 0 and 2000 training steps for each model.
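As a sketch only (not the actual Hydra config shipped with the repo), an optim_cfg matching that description could look like the following; configure_optimizers above instantiates whatever sits under the "optimizer" and "lr_scheduler" keys:

from omegaconf import OmegaConf

# Assumed structure: hydra.utils.instantiate builds torch.optim.Adam from the
# "_target_" entry, with params supplied by configure_optimizers at call time.
optim_cfg = OmegaConf.create(
    {
        "optimizer": {
            "_target_": "torch.optim.Adam",
            "lr": 1e-5,          # 1e-5 / 2e-5 for full fine-tuning, 3e-5 / 4e-5 for LoRA
            "weight_decay": 0.0,
        },
        # no "lr_scheduler" key, so configure_optimizers simply skips it
    }
)

# continuing the sketch above:
module = Seq2SeqLMModule(model=model, tokenizer=tokenizer, optim_cfg=optim_cfg)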
Thank you for your reply. How about the batch size?
The batch size was set to 16.
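Putting the numbers from this thread together (batch size 16, 2000 training steps), a rough end-to-end sketch could look like the following; the toy dataset is only there to make the snippet self-contained and should be replaced with the real training data:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Toy dataset placeholder: each example is a dict with "input_ids",
# "attention_mask" and "labels", which training_step() above forwards to the model.
texts = ["translate English to German: Hello, how are you?"]
targets = ["Hallo, wie geht es dir?"]
enc = tokenizer(texts, truncation=True)
lab = tokenizer(text_target=targets, truncation=True)
train_dataset = [
    {"input_ids": i, "attention_mask": a, "labels": l}
    for i, a, l in zip(enc["input_ids"], enc["attention_mask"], lab["input_ids"])
]

train_loader = DataLoader(
    train_dataset,
    batch_size=16,  # batch size reported in this thread
    shuffle=True,
    collate_fn=DataCollatorForSeq2Seq(tokenizer, model=model),
)

trainer = pl.Trainer(
    max_steps=2000,        # 2000 training steps per model, as described above
    log_every_n_steps=10,  # assumption, not stated in the thread
)
trainer.fit(module, train_dataloaders=train_loader)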