eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

Unable to run the code for Step 2: Run SFT #59

Closed: ppsmk388 closed this issue 10 months ago

ppsmk388 commented 11 months ago

When I run the "Run SFT" command from the complete example:

python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16

I encountered the following error:

Error executing job with overrides: ['model=pythia28', 'datasets=[hh]', 'loss=sft', 'exp_name=anthropic_dpo_pythia28', 'gradient_accumulation_steps=16', 'batch_size=64', 'eval_batch_size=32', 'trainer=FSDPTrainer', 'sample_during_eval=false', 'model.fsdp_policy_mp=bfloat16']
Traceback (most recent call last):
  File "train.py", line 111, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/huzhengyu/direct-preference-optimization/train.py", line 32, in worker_main
    wandb.init(
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 1169, in init
    raise e
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 1146, in init
    wi.setup(kwargs)
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 289, in setup
    wandb_login._login(
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/wandb/sdk/wandb_login.py", line 298, in _login
    wlogin.prompt_api_key()
  File "/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/site-packages/wandb/sdk/wandb_login.py", line 228, in prompt_api_key
    raise UsageError("api_key not configured (no-tty). call " + directive)
wandb.errors.UsageError: api_key not configured (no-tty). call wandb.login(key=[your_api_key])

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
/data/huzhengyu/anaconda3/envs/ddd/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

eric-mitchell commented 10 months ago

Looks like your issue is here:

wandb.errors.UsageError: api_key not configured (no-tty). call wandb.login(key=[your_api_key])

You should make sure you've logged in to wandb by running "wandb login" from a terminal, or disable wandb logging by passing wandb.enabled=False to your run command.
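Both options can be folded into a small launcher script so the failure is caught before the training workers are spawned. The following is a minimal sketch, not code from this repository: the helpers has_wandb_credentials and build_train_command are hypothetical, and the sketch assumes that wandb picks up the key either from the WANDB_API_KEY environment variable or from the api.wandb.ai entry that "wandb login" writes to ~/.netrc. If neither is found, it appends the wandb.enabled=False override suggested above.

```python
import netrc
import os
from pathlib import Path
from typing import List, Mapping, Optional

# Assumption: "wandb login" stores the API key in ~/.netrc under this host.
WANDB_NETRC_HOST = "api.wandb.ai"


def has_wandb_credentials(env: Mapping[str, str], netrc_path: Optional[Path] = None) -> bool:
    """Hypothetical helper: True if a W&B API key seems available without a prompt."""
    # Assumption: wandb reads WANDB_API_KEY from the environment; processes
    # started by mp.spawn inherit the parent's environment variables.
    if env.get("WANDB_API_KEY"):
        return True
    path = netrc_path if netrc_path is not None else Path.home() / ".netrc"
    if not path.is_file():
        return False
    try:
        entry = netrc.netrc(str(path)).authenticators(WANDB_NETRC_HOST)
    except (netrc.NetrcParseError, OSError):
        return False
    # authenticators() returns (login, account, password) or None;
    # the API key is stored in the password field.
    return entry is not None and bool(entry[2])


def build_train_command(
    overrides: List[str], env: Mapping[str, str], netrc_path: Optional[Path] = None
) -> List[str]:
    """Hypothetical helper: build the train.py command, disabling wandb if no key is found."""
    cmd = ["python", "-u", "train.py", *overrides]
    user_set_wandb = any(o.startswith("wandb.enabled=") for o in overrides)
    if not user_set_wandb and not has_wandb_credentials(env, netrc_path):
        # The override the maintainer suggests for turning off wandb logging.
        cmd.append("wandb.enabled=False")
    return cmd


if __name__ == "__main__":
    overrides = [
        "model=pythia28", "datasets=[hh]", "loss=sft",
        "exp_name=anthropic_dpo_pythia28", "gradient_accumulation_steps=2",
        "batch_size=64", "eval_batch_size=32", "trainer=FSDPTrainer",
        "sample_during_eval=false", "model.fsdp_policy_mp=bfloat16",
    ]
    # Print the command instead of running it; a real launcher might pass
    # the list to subprocess.run(cmd, check=True).
    print(" ".join(build_train_command(overrides, os.environ)))
```

Passing the command as a list to subprocess.run avoids shell quoting problems with an argument like datasets=[hh], and an explicit wandb.enabled=... override from the user is left untouched.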