foundation-model-stack / fms-hf-tuning

🚀 Collection of tuning recipes with HuggingFace SFTTrainer and PyTorch FSDP.
Apache License 2.0

Fix additional callbacks #199

Closed: VassilisVassiliadis closed this 5 days ago

VassilisVassiliadis commented 2 weeks ago

Description of the change

Fixes a small bug in the train() method of sft_trainer.py: trainer_callbacks appended the list of user-provided callbacks as a single element instead of concatenating the two lists.
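The difference between appending and concatenating is easy to see with plain Python lists. This is a minimal sketch, not the code from sft_trainer.py: `DummyCallback`, `merge_buggy` and `merge_fixed` are hypothetical stand-ins, and `extend()` is only one way to concatenate (`+` works too); the PR diff itself is not shown here.

```python
class DummyCallback:
    """Hypothetical stand-in for a transformers TrainerCallback."""

    def on_init_end(self, *args, **kwargs):
        return None


def merge_buggy(trainer_callbacks, additional_callbacks):
    # append() adds the whole list as ONE element, producing a nested list.
    merged = list(trainer_callbacks)
    merged.append(additional_callbacks)
    return merged


def merge_fixed(trainer_callbacks, additional_callbacks):
    # extend() adds each element individually, producing a flat list.
    merged = list(trainer_callbacks)
    merged.extend(additional_callbacks)
    return merged


builtin = [DummyCallback()]
extra = [DummyCallback(), DummyCallback()]

buggy = merge_buggy(builtin, extra)
fixed = merge_fixed(builtin, extra)

print(len(buggy), type(buggy[-1]).__name__)  # 2 list
print(len(fixed), type(fixed[-1]).__name__)  # 3 DummyCallback
```

The nested `list` entry in `buggy` is exactly what the callback handler later trips over.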

Without this change, the following exception is raised:

self = <transformers.trainer_callback.CallbackHandler object at 0x32b558fa0>
event = 'on_init_end'
args = TrainingArguments(output_dir='/var/folders/p_/lhp3_gn503l7djn80tdzkf3h0000gn/T/tmpu2qek8_i', overwrite_output_dir=Fals...t_modules=None, batch_eval_metrics=False, cache_dir=None, max_seq_length=4096, packing=False, trackers=['file_logger'])
state = TrainerState(epoch=None, global_step=0, max_steps=0, logging_steps=500, eval_steps=500, save_steps=500, train_batch_si..., 'should_epoch_stop': False, 'should_save': False, 'should_evaluate': False, 'should_log': False}, 'attributes': {}}})
control = TrainerControl(should_training_stop=False, should_epoch_stop=False, should_save=False, should_evaluate=False, should_log=False)
kwargs = {}
callback = [<transformers.trainer_callback.TrainerCallback object at 0x30f3508e0>]
result = None

    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
>           result = getattr(callback, event)(
                args,
                state,
                control,
                model=self.model,
                tokenizer=self.tokenizer,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                **kwargs,
            )
E           AttributeError: 'list' object has no attribute 'on_init_end'

../venv/lib/python3.10/site-packages/transformers/trainer_callback.py:498: AttributeError
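The traceback shows `CallbackHandler.call_event` looking up each event method by name with `getattr` on every entry of `self.callbacks`, so a nested list entry fails as soon as the first event (`on_init_end`) fires. The sketch below reproduces that failure mode in isolation; `call_event` here is a hypothetical, simplified re-implementation of the loop above (it drops the `args`, `state`, `control` and keyword arguments), not the transformers code.

```python
class DummyCallback:
    """Hypothetical stand-in for a transformers TrainerCallback."""

    def __init__(self):
        self.called = False

    def on_init_end(self, *args, **kwargs):
        self.called = True


def call_event(callbacks, event):
    # Simplified, hypothetical version of the loop in CallbackHandler.call_event:
    # fetch the event method by name from every callback and invoke it.
    for callback in callbacks:
        getattr(callback, event)()


# The shape produced by append(): one real callback plus a nested list.
nested = [DummyCallback(), [DummyCallback()]]

error_message = None
try:
    call_event(nested, "on_init_end")
except AttributeError as exc:
    error_message = str(exc)

print(error_message)  # 'list' object has no attribute 'on_init_end'
```

The first (real) callback still runs; the error only appears when the loop reaches the nested list.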

Related issue number

Contributes to #142

How to verify the PR

Call train() with an additional_callbacks argument that is a list containing at least one callback.
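The verification step can also be captured as a pytest-style regression test. This is a sketch under assumptions: the full signature of train() is not shown in this PR, so `merge_callbacks` is a hypothetical stand-in for the callback-merging step inside it, and `RecordingCallback` is a hypothetical callback that records the events it receives. A real test would call train() itself with an additional_callbacks list.

```python
class RecordingCallback:
    """Hypothetical callback that records the events it receives."""

    def __init__(self):
        self.events = []

    def on_init_end(self, *args, **kwargs):
        self.events.append("on_init_end")


def merge_callbacks(trainer_callbacks, additional_callbacks=None):
    # Hypothetical stand-in for the merging step inside train():
    # concatenate the two lists; None is treated as "no extra callbacks".
    return list(trainer_callbacks) + list(additional_callbacks or [])


def test_additional_callbacks_are_flattened():
    extra = [RecordingCallback()]
    callbacks = merge_callbacks([RecordingCallback()], extra)
    # Every entry must be a callback object, never a nested list.
    assert all(isinstance(cb, RecordingCallback) for cb in callbacks)
    # Dispatch the event the way the handler does: by name via getattr.
    for cb in callbacks:
        getattr(cb, "on_init_end")()
    assert extra[0].events == ["on_init_end"]


def test_no_additional_callbacks():
    assert len(merge_callbacks([RecordingCallback()])) == 1
```

Running `pytest` on a file containing these functions should pass with the concatenating merge; swapping in an `append()`-based merge makes the first test fail on the `isinstance` check.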

Was the PR tested

michael-johnston commented 1 week ago

@Ssukriti @anhuong This fix would really help the benchmarking effort, as it removes the need for us to maintain a fork to work around the bug (and the risk of that fork falling out of sync). Any help in getting it merged would be appreciated. :-)

dushyantbehl commented 1 week ago

LGTM. Thanks @VassilisVassiliadis, I missed this check in my unit-testing PR.

VassilisVassiliadis commented 1 week ago

No problem!

michael-johnston commented 5 days ago

@anhuong @Ssukriti @alex-jw-brooks Just a bump on this, as we'd like to get rid of the need for a fork ASAP, and this is just a single-line change :-)

VassilisVassiliadis commented 5 days ago

no worries!