allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0
2.13k stars 191 forks source link

Problem with BLEURT reward function #34

Open eublefar opened 1 year ago

eublefar commented 1 year ago

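The BLEURT reward function fails with `TypeError: cannot pickle '_thread.RLock' object` in multiprocessing environments. This is probably because the TensorFlow model it holds cannot be pickled. `build_env` passes the reward function into every environment. `SubprocVecEnv` then cloudpickles each environment factory to send it to its subprocess, so everything the reward function references, including the model, must be picklable.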

The error reproduces both in a local environment and on Google Colab.

Here is the full stack trace:

│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:149 in __init__       │
│                                                                                           │
│   146 │   │   self._train_eval_config = train_eval_config                                 │
│   147 │   │   self._tracker = tracker                                                     │
│   148 │   │   self._experiment_name = experiment_name                                     │
│ ❱ 149 │   │   self._setup()                                                               │
│   150 │                                                                                   │
│   151 │   def _setup(self):                                                               │
│   152 │   │   # load trainer state from available previous checkpoint if available        │
│                                                                                           │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:162 in _setup         │
│                                                                                           │
│   159 │   │   │   self._train_eval_config.get("metrics", []))                             │
│   160 │   │   self._samples_by_split = build_datapool(                                    │
│   161 │   │   │   self._datapool_config)                                                  │
│ ❱ 162 │   │   self._env = build_env(self._env_config, self._reward_fn,                    │
│   163 │   │   │   │   │   │   │     self._tokenizer, self._samples_by_split["train"])     │
│   164 │   │   self._alg = build_alg(self._on_policy_alg_config,                           │
│   165 │   │   │   │   │   │   │     self._env, self._tracker,                             │
│                                                                                           │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:90 in build_env       │
│                                                                                           │
│    87 │   │   "samples": train_samples,                                                   │
│    88 │   }                                                                               │
│    89 │   env_kwargs = {**env_kwargs, **env_config.get("args", {})}                       │
│ ❱  90 │   env = make_vec_env(TextGenEnv,                                                  │
│    91 │   │   │   │   │      n_envs=env_config.get(                                       │
│    92 │   │   │   │   │   │      "n_envs", 1),                                            │
│    93 │   │   │   │   │      vec_env_cls=SubprocVecEnv,                                   │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/e │
│ nv_util.py:105 in make_vec_env                                                            │
│                                                                                           │
│   102 │   │   # Default: use a DummyVecEnv                                                │
│   103 │   │   vec_env_cls = DummyVecEnv                                                   │
│   104 │                                                                                   │
│ ❱ 105 │   return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_en │
│   106                                                                                     │
│   107                                                                                     │
│   108 def make_atari_env(                                                                 │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/subproc_vec_env.py:106 in __init__                                                 │
│                                                                                           │
│   103 │   │   │   args = (work_remote, remote, CloudpickleWrapper(env_fn))                │
│   104 │   │   │   # daemon=True: if the main process crashes, we should not cause things  │
│   105 │   │   │   process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype │
│ ❱ 106 │   │   │   process.start()                                                         │
│   107 │   │   │   self.processes.append(process)                                          │
│   108 │   │   │   work_remote.close()                                                     │
│   109                                                                                     │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/process.py:121 in start  │
│                                                                                           │
│   118 │   │   assert not _current_process._config.get('daemon'), \                        │
│   119 │   │   │      'daemonic processes are not allowed to have children'                │
│   120 │   │   _cleanup()                                                                  │
│ ❱ 121 │   │   self._popen = self._Popen(self)                                             │
│   122 │   │   self._sentinel = self._popen.sentinel                                       │
│   123 │   │   # Avoid a refcycle if the target function holds an indirect                 │
│   124 │   │   # reference to the process object (see bpo-30775)                           │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/context.py:291 in _Popen │
│                                                                                           │
│   288 │   │   @staticmethod                                                               │
│   289 │   │   def _Popen(process_obj):                                                    │
│   290 │   │   │   from .popen_forkserver import Popen                                     │
│ ❱ 291 │   │   │   return Popen(process_obj)                                               │
│   292 │                                                                                   │
│   293 │   class ForkContext(BaseContext):                                                 │
│   294 │   │   _name = 'fork'                                                              │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:35   │
│ in __init__                                                                               │
│                                                                                           │
│   32 │                                                                                    │
│   33 │   def __init__(self, process_obj):                                                 │
│   34 │   │   self._fds = []                                                               │
│ ❱ 35 │   │   super().__init__(process_obj)                                                │
│   36 │                                                                                    │
│   37 │   def duplicate_for_child(self, fd):                                               │
│   38 │   │   self._fds.append(fd)                                                         │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_fork.py:19 in      │
│ __init__                                                                                  │
│                                                                                           │
│   16 │   │   util._flush_std_streams()                                                    │
│   17 │   │   self.returncode = None                                                       │
│   18 │   │   self.finalizer = None                                                        │
│ ❱ 19 │   │   self._launch(process_obj)                                                    │
│   20 │                                                                                    │
│   21 │   def duplicate_for_child(self, fd):                                               │
│   22 │   │   return fd                                                                    │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:47   │
│ in _launch                                                                                │
│                                                                                           │
│   44 │   │   set_spawning_popen(self)                                                     │
│   45 │   │   try:                                                                         │
│   46 │   │   │   reduction.dump(prep_data, buf)                                           │
│ ❱ 47 │   │   │   reduction.dump(process_obj, buf)                                         │
│   48 │   │   finally:                                                                     │
│   49 │   │   │   set_spawning_popen(None)                                                 │
│   50                                                                                      │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/reduction.py:60 in dump  │
│                                                                                           │
│    57                                                                                     │
│    58 def dump(obj, file, protocol=None):                                                 │
│    59 │   '''Replacement for pickle.dump() using ForkingPickler.'''                       │
│ ❱  60 │   ForkingPickler(file, protocol).dump(obj)                                        │
│    61                                                                                     │
│    62 #                                                                                   │
│    63 # Platform specific definitions                                                     │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/base_vec_env.py:371 in __getstate__                                                │
│                                                                                           │
│   368 │   │   self.var = var                                                              │
│   369 │                                                                                   │
│   370 │   def __getstate__(self) -> Any:                                                  │
│ ❱ 371 │   │   return cloudpickle.dumps(self.var)                                          │
│   372 │                                                                                   │
│   373 │   def __setstate__(self, var: Any) -> None:                                       │
│   374 │   │   self.var = cloudpickle.loads(var)                                           │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:73 in dumps                                                                         │
│                                                                                           │
│    70 │   │   │   cp = CloudPickler(                                                      │
│    71 │   │   │   │   file, protocol=protocol, buffer_callback=buffer_callback            │
│    72 │   │   │   )                                                                       │
│ ❱  73 │   │   │   cp.dump(obj)                                                            │
│    74 │   │   │   return file.getvalue()                                                  │
│    75                                                                                     │
│    76 else:                                                                               │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:632 in dump                                                                         │
│                                                                                           │
│   629 │                                                                                   │
│   630 │   def dump(self, obj):                                                            │
│   631 │   │   try:                                                                        │
│ ❱ 632 │   │   │   return Pickler.dump(self, obj)                                          │
│   633 │   │   except RuntimeError as e:                                                   │
│   634 │   │   │   if "recursion" in e.args[0]:                                            │
│   635 │   │   │   │   msg = (                                                             │
╰───────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: cannot pickle '_thread.RLock' object