huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DPO Evaluation with WandB triggers a `cannot pickle '_thread.lock' object` failure #1849

Closed fozziethebeat closed 1 month ago

fozziethebeat commented 1 month ago

This is a recurrence of #914.

Copy pasting my comments so they're here too:

I think this problem has resurfaced at some point. I'm running TRL indirectly through Axolotl and I'm seeing this line triggering the ProgressCallback.

Then, naturally, when that callback does `logs = copy.deepcopy(logs)`, the WandB table in the logs breaks things with the same failure.

The key parts of my stacktrace are:

  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 2221, in train
    return inner_training_loop(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 2728, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3299, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3240, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 4277, in evaluate
    output = eval_loop(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1814, in evaluation_loop
    self.log(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1856, in log
    return super().log(logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3802, in log
    self.control = self.callback_handler.on_log(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 628, in on_log
    return self.call_event("on_log", args, state, control, logs=logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 641, in call_event
    result = getattr(callback, event)(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 775, in on_log
    logs = copy.deepcopy(logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)

The relevant library versions are:

transformers==4.42.3
trl==0.9.6
wandb==0.17.4

The only solutions I've found are to drop wandb logging entirely or to delete the log line in the DPO trainer.

I feel like the right fix is for the DPO trainer's call to `self.log(...)` not to trigger the ProgressCallback; it's odd that it does.
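
For anyone who needs a stopgap, here's a rough sketch of what sidestepping the ProgressCallback could look like from user code (`trainer` stands in for whatever DPOTrainer instance your stack builds; dropping the callback costs you the tqdm progress bar):

from transformers import ProgressCallback, Trainer

def apply_workaround(trainer: Trainer) -> None:
    # Drop the callback whose on_log deep-copies the logs dict; the logged
    # wandb.Table should then only reach WandbCallback, which knows how to
    # handle tables. The alternative is to avoid the table entirely by
    # setting report_to="none" in the training arguments, at the cost of
    # losing wandb logging altogether.
    trainer.remove_callback(ProgressCallback)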

fozziethebeat commented 1 month ago

My debugging suggests that the WandbCallback does something to the table such that it can no longer be pickled after it runs. I hacked the code locally to do a deepcopy after the `wandb.log` call and it broke in exactly the same way.

I'm guessing wandb changed something so that logging a table prevents it from being pickled afterwards?

fozziethebeat commented 1 month ago

I've extracted the core flow that the DPO log statement triggers, and the combination simply doesn't work:

import wandb
import copy
wandb.login()
run = wandb.init(
    project="wandb-debug",
)
# make a table
table = wandb.Table(
    columns=["Prompt", "Policy", "Ref Model"],
    rows=[
        ["prompt", "other prompt", "more prompt"],
    ],
)
# prove copying before logging works
copy.deepcopy({
    "game_log": table,
})
# log the table
wandb.log({"table": table})
# pickle failure here
copy.deepcopy({
    "game_log": table,
})

I tried this with a sample of wandb versions from 0.14.0 up to 0.17.4 and they all triggered this failure, so I'm assuming it's intended WandB behavior.

Ultimately I think the wandb callback should probably remove the table from the logs before other callbacks try to do a deep copy.
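
A minimal sketch of that idea (not taken from the actual transformers or wandb code, just to show the filtering a callback could do before copying):

import copy
import wandb

def copy_logs_without_tables(logs: dict) -> dict:
    # Deep-copy everything except wandb.Table values, which stop being
    # picklable (their object graph holds a '_thread.lock') once they
    # have been passed to wandb.log().
    return {
        k: (v if isinstance(v, wandb.Table) else copy.deepcopy(v))
        for k, v in logs.items()
    }

A callback could pass the logs through something like this instead of deep-copying the whole dict.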

skylooop commented 1 month ago

I encountered the same problem when doing DPO with generate_during_eval=True.

fozziethebeat commented 1 month ago

Yes, this flow only triggers when using that setting.

skylooop commented 1 month ago

> Yes, this flow only triggers when using that setting.

Did you manage to find a solution, or a way to perform generation during evaluation? I tried various options for logging directly to wandb with a custom callback, but my process hangs under DeepSpeed when I log only on the main process.
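
For reference, this is the pattern I suspect causes the hang (just a sketch; model and inputs stand in for whatever the callback actually receives): the generation has to run on every rank, and only the wandb call should be restricted to the main process, otherwise the other ranks block forever waiting on the collective op.

from accelerate import Accelerator
import wandb

def log_eval_samples(accelerator: Accelerator, model, inputs) -> None:
    # Under DeepSpeed/ZeRO the forward pass inside generate() is a collective
    # operation, so every rank must execute it, not just rank 0.
    outputs = model.generate(**inputs)
    decoded = [str(o) for o in outputs]
    # Only the actual logging should be restricted to the main process.
    if accelerator.is_main_process:
        wandb.log({"samples": wandb.Table(columns=["output"], data=[[d] for d in decoded])})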

fozziethebeat commented 1 month ago

I haven't found a good solution yet.

fozziethebeat commented 1 month ago

The problem comes from this commit, which introduced a deepcopy. Based on the release history, it should be present in transformers versions from https://github.com/huggingface/transformers/releases/tag/v4.40.0 onwards.

fozziethebeat commented 1 month ago

As noted above, I've reported this to Transformers with my proposed fix: don't pickle the log values with a deep copy.
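
Roughly, the change I'm proposing looks like this; it's only a sketch of the idea, assuming the callback needs its own copy just so it can pop display-only keys such as total_flos, and not the patch that was actually merged:

import copy

def copy_logs_for_display(logs: dict) -> dict:
    # A shallow copy is enough to pop keys without mutating the logs dict
    # shared between callbacks, and it never pickles the values, so a
    # logged wandb.Table no longer raises
    # "cannot pickle '_thread.lock' object".
    display_logs = copy.copy(logs)
    display_logs.pop("total_flos", None)
    return display_logs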

fozziethebeat commented 1 month ago

Transformers is now fixed as of this commit. I managed to track down where Transformers had accepted a PR that deep-copied WandB tables.

skylooop commented 1 month ago

@fozziethebeat Sorry to bother you, but maybe you've observed the following: with generate_during_eval=True, the ref_model responses are similar to those of the currently trained policy model. It seems like ref_model is being updated, which should not be the case for DPO. This issue is very similar.

fozziethebeat commented 1 month ago

Yeah, that's the code path that triggers this. But it's fixed now with the latest version of transformers.