huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

LoraConfig not JSON serializable for logging to wandb #2239

Open v-bosch opened 12 hours ago

v-bosch commented 12 hours ago

System Info

Hi all, I am running a custom LLaVA-style transformer based on Llama 2 with PEFT, QLoRA, and FSDP. Training starts, but I get a strange error from Weights & Biases: it seems to be trying to log the LoraConfig, and this fails because a ListConfig is not JSON serializable. I have not found this error mentioned anywhere else, so I would be very grateful for any information. Unfortunately I cannot share the code yet (and it is huge), but please let me know if I can provide more details. I am mainly hoping for pointers on where to start debugging this.

Versions: wandb 0.18, Python 3.10, peft 0.12.0, transformers 4.46.3, trl 0.12.1

Reproduction

Here are a few bits of code.

model = llava_LM(cfg)  # This is my own model class

lora_config = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    lora_dropout=cfg.lora_dropout,
    target_modules=cfg.target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

args = SFTConfig(
    output_dir=SAVE_PATH,
    remove_unused_columns=False,
    seed=cfg.seed,
    num_train_epochs=cfg.n_pretrain_epochs,
    gradient_accumulation_steps=len(cfg.subjects),
    learning_rate=cfg.pretrain_lr,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=len(cfg.subjects) * 100,  # because of gradient acc.
    optim=cfg.optimizer,
    fp16=False,
    bf16=True,
    evaluation_strategy="no",  # "epoch",
    label_names=['captions'],  # eval does not function properly for decoder-only HF models
    report_to='wandb',
    save_total_limit=1,
    save_strategy="steps",
    save_steps=20,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    weight_decay=cfg.l2,
    prediction_loss_only=True,
    dataset_kwargs={"skip_prepare_dataset": True},
)

trainer = SFTTrainer(
    model=model,
    tokenizer=model.llm_tokenizer,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    # peft_config=lora_config, # not necessary?
    data_collator=new_collate_fn(train_dataset),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()

The error I get is the following:

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/staff/v/vbosch/Brformer/scripts/torch_scripts/train_llava.py", line 425, in main
    run_model(model_id, train_dataset, valid_dataset, glasser_vertices_idx, cfg)
  File "/home/staff/v/vbosch/Brformer/scripts/torch_scripts/train_llava.py", line 379, in run_model
    pretrain(cfg, model, lora_config, train_dataset, valid_dataset,
  File "/home/staff/v/vbosch/Brformer/scripts/torch_scripts/train_llava.py", line 355, in pretrain
    trainer.train()
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/trainer.py", line 2382, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/trainer_callback.py", line 468, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/trainer_callback.py", line 518, in call_event
    result = getattr(callback, event)(
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 911, in on_train_begin
    self.setup(args, state, model, **kwargs)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 843, in setup
    self._wandb.config.update(combined_dict, allow_val_change=True)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/wandb_config.py", line 189, in update
    self._callback(data=sanitized)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/wandb_run.py", line 403, in wrapper_fn
    return func(self, *args, **kwargs)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/wandb_run.py", line 1396, in _config_callback
    self._backend.interface.publish_config(key=key, val=val, data=data)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/interface/interface.py", line 185, in publish_config
    cfg = self._make_config(data=data, key=key, val=val)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/interface/interface.py", line 139, in _make_config
    update.value_json = json_dumps_safer(json_friendly(v)[0])
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/util.py", line 831, in json_dumps_safer
    return dumps(obj, cls=WandBJSONEncoder, **kwargs)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/util.py", line 782, in default
    return json.JSONEncoder.default(self, obj)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type ListConfig is not JSON serializable

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
wandb: 🚀 View run Llama_QA_LORA_11-27_15-00-16_pretrain_PCA95_subj12_fsdp_test_visual at: https://wandb.ai/vbosch/CorText_llama/runs/3bafvb8b
wandb: Find logs at: wandb/run-20241127_150016-3bafvb8b/logs
[2024-11-27 15:03:55,934] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 953481 closing signal SIGTERM
[2024-11-27 15:03:56,299] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 953480) of binary: /home/staff/v/vbosch/.conda/envs/BQA_n/bin/python
Traceback (most recent call last):
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1161, in launch_command
    multi_gpu_launcher(args)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/accelerate/commands/launch.py", line 799, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

When I check what wandb is trying to serialize, I find the following, which appears to be the lora_config. (I cut off part of the config because it is so large; I hope this is sufficient.)

> /home/staff/v/vbosch/.conda/envs/BQA_n/lib/python3.10/site-packages/wandb/sdk/interface/interface.py(186)publish_config()
-> cfg = self._make_config(data=data, key=key, val=val)
(Pdb) data
{'peft_config': {'default': {'peft_type': <PeftType.LORA: 'LORA'>, 'auto_mapping': None, 'base_model_name_or_path': None, 'revision': None, 'task_type': 'CAUSAL_LM', 'inference_mode': False, 'r': 32, 'target_modules': ['q_proj', 'v_proj'], 'lora_alpha': 16, 'lora_dropout': 0.05, 'fan_in_fan_out': False, 'bias': 'none', 'use_rslora': False, 'modules_to_save': None, 'init_lora_weights': True, 'layers_to_transform': None, 'layers_pattern': None, 'rank_pattern': {}, 'alpha_pattern': {}, 'megatron_config': None, 'megatron_core': 'megatron.core', 'loftq_config': {}, 'use_dora': False, 'layer_replication': None, 'runtime_config': {'ephemeral_gpu_offload': False}}}, 'vocab_size': 32001, 'max_position_embeddings': 4096, 'hidden_size': 4096, 'intermediate_size':   ... etc. 

Expected behavior

Either wandb should not log the PEFT config, or the config should be JSON serializable.

BenjaminBossan commented 11 hours ago

Thanks for reporting this issue. I could not exactly reproduce the error, as calling

from wandb.util import WandBJSONEncoder
from wandb.sdk.lib.json_util import dumps
from peft import LoraConfig

dumps(LoraConfig(target_modules=["foo", "bar"]), cls=WandBJSONEncoder)

results in

TypeError: Object of type LoraConfig is not JSON serializable

and

dumps(LoraConfig(target_modules=["foo", "bar"]).to_dict(), cls=WandBJSONEncoder)

works successfully. ListConfig is not a class defined in PEFT and AFAICT it's not in wandb either.

> Checking what is in the config that wandb is trying to serialize, I find the following, this seems to be the lora_config

Honestly, nothing in there strikes me as something that can't be serialized.
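
One reason a dump like this can look harmless is that a list-like object may print exactly like a plain list. The sketch below is a stdlib-only illustration, not code from PEFT, wandb, or hydra: `FakeListConfig` and `find_unserializable` are hypothetical names, and the stand-in class only mimics a container that is not a `list` subclass. It walks a nested dict and reports the path and type of every value the standard `json` encoder would reject (wandb's own encoder may accept a few additional types).

```python
import json
from collections.abc import Sequence


class FakeListConfig(Sequence):
    """Hypothetical stand-in for a list-like config object that is not a `list`."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)

    def __repr__(self):
        # Prints exactly like a plain list, which hides the real type.
        return repr(self._items)


# Leaf types the stdlib json encoder handles directly.
JSON_NATIVE = (str, int, float, bool, type(None))


def find_unserializable(obj, path="config"):
    """Return (path, type name) pairs for values the stdlib json encoder rejects."""
    if isinstance(obj, JSON_NATIVE):
        return []
    if isinstance(obj, dict):
        found = []
        for key, value in obj.items():
            found += find_unserializable(value, f"{path}[{key!r}]")
        return found
    if isinstance(obj, (list, tuple)):
        found = []
        for i, value in enumerate(obj):
            found += find_unserializable(value, f"{path}[{i}]")
        return found
    # Anything else is reported together with where it was found.
    return [(path, type(obj).__name__)]


data = {
    "peft_config": {
        "default": {"r": 32, "target_modules": FakeListConfig(["q_proj", "v_proj"])}
    }
}
print(data)  # the repr shows an ordinary-looking list
print(find_unserializable(data))
```

Running a helper like this on `data` inside the debugger, or simply calling `type(...)` on `data['peft_config']['default']['target_modules']`, would show whether one of the values is something other than a built-in container.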

An issue similar to yours was reported in #567, and a user there mentioned that it is related to the use of hydra (which seems to have a ListConfig class). As your log contains HYDRA_FULL_ERROR, I assume you use it too. If so, could you please investigate whether that could be the source of the error? If it is indeed related to hydra, perhaps you can open an issue on wandb and ask whether they can add special handling for hydra configs.
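
If the hydra config does turn out to be the source, one possible workaround is to convert config values to plain Python containers before they reach `LoraConfig`. The following is a minimal stdlib-only sketch under that assumption, which this thread has not confirmed: `FakeListConfig` is a hypothetical stand-in for a list-like config object, and `to_plain` is a hypothetical helper. With real hydra/OmegaConf objects, `OmegaConf.to_container(cfg.target_modules, resolve=True)` or `list(cfg.target_modules)` would play the role of `to_plain`.

```python
import json
from collections.abc import Mapping, Sequence


class FakeListConfig(Sequence):
    """Hypothetical stand-in for a list-like config object that is not a `list`."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)


def to_plain(obj):
    """Recursively copy mappings and sequences into plain dicts and lists."""
    if isinstance(obj, (str, bytes)):
        return obj  # strings are sequences too, but must stay as they are
    if isinstance(obj, Mapping):
        return {key: to_plain(value) for key, value in obj.items()}
    if isinstance(obj, Sequence):
        return [to_plain(item) for item in obj]
    return obj


target_modules = FakeListConfig(["q_proj", "v_proj"])

try:
    # The json encoder only knows built-in containers, so this raises TypeError.
    json.dumps({"target_modules": target_modules})
except TypeError as err:
    print(err)

# After conversion the value is an ordinary list and serializes without error.
print(json.dumps({"target_modules": to_plain(target_modules)}))
```

In the reproduction above, this would correspond to passing something like `target_modules=list(cfg.target_modules)` when building the `LoraConfig`, so that whatever is later logged contains only built-in types.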