huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

can't resume lora training due to wandb logging num params #33320

Closed mdabbah-deci closed 16 hours ago

mdabbah-deci commented 1 week ago

System Info

Hi, I have some trained checkpoints that I'd like to resume from. All of them are LoRA checkpoints, but when resuming I get the following error in the Trainer:

 trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2202, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 460, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 507, in call_event
    result = getattr(callback, event)(
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 900, in on_train_begin
    self.setup(args, state, model, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 853, in setup
    self._wandb.config["model/num_parameters"] = model.num_parameters()
  File "/root/.local/lib/python3.10/site-packages/wandb/sdk/wandb_config.py", line 149, in __setitem__
    key, val = self._sanitize(key, val)
  File "/root/.local/lib/python3.10/site-packages/wandb/sdk/wandb_config.py", line 258, in _sanitize
    raise config_util.ConfigError(
wandb.sdk.lib.config_util.ConfigError: Attempted to change value of key "model/num_parameters" from 0 to 266240

I assume the fact that this is LoRA training is relevant, because the error describes a change in the number of parameters (which shouldn't have been logged as 0 in the first place).

And even though the wandb config dict was set with allow_val_change=True at integration_utils.py#L838, I still get the above error from the plain assignment at integration_utils.py#L853.
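For what it's worth, the rejection itself is plain wandb behavior and is easy to reproduce outside the Trainer. A minimal sketch (the project name is made up): config.update(..., allow_val_change=True) permits changing an existing value, while plain item assignment that changes a value raises ConfigError, which is exactly what the traceback shows.

import wandb

run = wandb.init(mode="offline", project="repro")
run.config["model/num_parameters"] = 0  # first write succeeds

# Changing the value through config.update with allow_val_change=True is permitted.
run.config.update({"model/num_parameters": 266240}, allow_val_change=True)

# But plain item assignment that changes the value raises
# wandb.sdk.lib.config_util.ConfigError, as in the traceback above.
run.config["model/num_parameters"] = 0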

Any idea how to solve this?

Thanks

Who can help?

@muellerzr @SunMarc

Information

Tasks

Reproduction

Step 1: Train a small model with DPO + LoRA.
Step 2: Try to resume with trainer.train(resume_from_checkpoint=True) while setting:

 os.environ["WANDB_RESUME"] = "allow"
os.environ["WANDB_RUN_ID"] = script_args.run_id # same run_id as previous run
DPOConfig(
    output_dir=script_args.output_dir # same previous run out dir
    run_name=script_args.run_name # same run_name as previous run 
..
)

...

trainer.train(resume_from_checkpoint=True)
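One possible workaround until this is fixed (an untested sketch; it assumes the trainer and model are already built, and relies on WandbCallback reusing an already-initialized run): initialize the resumed run yourself and pre-seed "model/num_parameters" through the allow_val_change path, so the callback's later plain assignment writes an unchanged value instead of changing it.

import os
import wandb

os.environ["WANDB_RESUME"] = "allow"
os.environ["WANDB_RUN_ID"] = script_args.run_id

# Initialize the resumed run before trainer.train() so WandbCallback reuses it.
run = wandb.init(name=script_args.run_name)

# Pre-seed the key via config.update(..., allow_val_change=True); the callback's
# later plain assignment of the same value is then a no-op, not a value change.
run.config.update(
    {"model/num_parameters": model.num_parameters()},
    allow_val_change=True,
)

trainer.train(resume_from_checkpoint=True)

Alternatively, passing report_to="none" in the training arguments sidesteps the error entirely, at the cost of losing wandb logging.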

Expected behavior

Being able to resume the previous training run.

ZIYU-DEEP commented 5 days ago

Submitted a PR for this: #33464
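For context, a sketch of the shape such a fix could take (this is only an illustration; see the PR itself for the actual change): route the num-parameters write in WandbCallback.setup() through the allow_val_change path instead of plain item assignment.

# transformers/integrations/integration_utils.py, WandbCallback.setup()
# before:
#     self._wandb.config["model/num_parameters"] = model.num_parameters()
# sketch of a resume-tolerant version:
self._wandb.config.update(
    {"model/num_parameters": model.num_parameters()}, allow_val_change=True
)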