Hi @SunriseEastSea

Here is how you can do that.

First, define a callback:

from transformers import TrainerCallback
class LossLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            loss = logs.get("loss")
            if loss is not None:
                print(f"Step {state.global_step}: Loss: {loss}")
                # Optionally, store the loss in a file or a list for further processing
                with open("training_loss_log.txt", "a") as log_file:
                    log_file.write(f"Step {state.global_step}: Loss: {loss}\n")
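One detail worth knowing: on_log only fires when the Trainer actually emits a log, which is controlled by logging_steps in TrainingArguments (500 steps by default, which can be too coarse for short runs). If you would rather collect the losses in memory for plotting instead of writing to a file, a minimal sketch of a list-based variant could look like this (the class name LossHistoryCallback and its losses attribute are just illustrative, not part of transformers):

from transformers import TrainerCallback

class LossHistoryCallback(TrainerCallback):
    """Illustrative variant that keeps (step, loss) pairs in memory."""

    def __init__(self):
        self.losses = []  # one (step, loss) entry per logging event

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append((state.global_step, logs["loss"]))

After training finishes, the callback instance's losses list can be plotted or post-processed however you like.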
Then add the callback to the trainer:
from transformers import TrainingArguments

training_args = TrainingArguments(
    self.experiment_dir,
    num_train_epochs=self.epochs,
    per_device_train_batch_size=self.batch_size,
    save_strategy="no",
    **self.train_hyperparameters,
)
# Instantiate the custom callback
loss_logging_callback = LossLoggingCallback()

# Create the trainer with the callback
tabula_trainer = TabulaTrainer(
    self.model,
    training_args,
    train_dataset=tabula_ds,
    tokenizer=self.tokenizer,
    data_collator=TabulaDataCollator(self.tokenizer),
    callbacks=[loss_logging_callback],  # Add the callback here
)
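With the callback attached, run training as usual. Trainer also exposes add_callback, so as an equivalent alternative you can attach the callback after construction; a short usage sketch:

# Equivalent alternative to passing callbacks=[...] above:
# tabula_trainer.add_callback(loss_logging_callback)

# Run training; on_log fires at each logging event, printing the loss
# and appending a line to training_loss_log.txt
tabula_trainer.train()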
You can implement this, and if it is bug-free, you can create a PR and I will merge it.