astooke / rlpyt

Reinforcement Learning in PyTorch
MIT License

Proper way to re-load a replay buffer, for resuming training (of agents that use a replay buffer) #111

Closed DanielTakeshi closed 4 years ago

DanielTakeshi commented 4 years ago

Hi @astooke, I am trying to write some code that shows how to save and load models for resuming training (not just for inference). For algorithms that use a replay buffer, this may require more code, because we also need to restore the replay buffer as it was at the time of saving.

The most naive way to save and load a replay buffer does not seem to work. I cloned the repo and made these changes:

(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ git diff
diff --git a/linux_cuda10.yml b/linux_cuda10.yml
index fb73927..68ea0c2 100644
--- a/linux_cuda10.yml
+++ b/linux_cuda10.yml
@@ -1,4 +1,4 @@
-name: rlpyt
+name: rlpyt-astooke
 channels:
 - pytorch
 dependencies:
diff --git a/rlpyt/runners/minibatch_rl.py b/rlpyt/runners/minibatch_rl.py
index 6850b93..7dd4741 100644
--- a/rlpyt/runners/minibatch_rl.py
+++ b/rlpyt/runners/minibatch_rl.py
@@ -296,6 +296,18 @@ class MinibatchRlEval(MinibatchRlBase):
         specified log interval.
         """
         n_itr = self.startup()
+
+        import pickle, sys
+        replay_buffer = self.algo.replay_buffer
+        print('saving buffer ...')
+        with open('buffer.pkl', 'wb') as fh:
+            pickle.dump(replay_buffer, fh, protocol=4)
+        print('done with saving buffer. now let us load ...')
+        with open('buffer.pkl', 'rb') as fh:
+            replay_buffer = pickle.load(fh)
+        print('buffer loaded')
+        sys.exit()
+
         with logger.prefix(f"itr #0 "):
             eval_traj_infos, eval_time = self.evaluate_agent(0)
             self.log_diagnostics(0, eval_traj_infos, eval_time)
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ 

and then ran the script. It seems to have saved and loaded:

(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ python examples/example_5.py 
logger_context received log_dir outside of /home/seita/rlpyt_astooke/data: prepending by /home/seita/rlpyt_astooke/data/local/<yyyymmdd>/
2020-02-25 12:48:50.955948  | dqn_pong_0 Runner  master CPU affinity: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15].
2020-02-25 12:48:50.956036  | dqn_pong_0 Runner  master Torch threads: 8.
using seed 4803
2020-02-25 12:48:50.958320  | dqn_pong_0 Total parallel evaluation envs: 10.
using seed 4805
using seed 4804
2020-02-25 12:48:51.552859  | dqn_pong_0 Sampler rank 1 initialized, CPU affinity [1], Torch threads 1, Seed 4805
2020-02-25 12:48:51.552953  | dqn_pong_0 Sampler rank 0 initialized, CPU affinity [0], Torch threads 1, Seed 4804
2020-02-25 12:48:54.984253  | dqn_pong_0 Sampler decorrelating envs, max steps: 0
2020-02-25 12:48:55.877700  | dqn_pong_0 Running 750000 iterations of minibatch RL.
2020-02-25 12:48:55.878391  | dqn_pong_0 From sampler batch size 64, training batch size 128, and replay ratio 8, computed 4 updates per iteration.
2020-02-25 12:48:55.878459  | dqn_pong_0 Agent setting min/max epsilon itrs: 781, 15625
2020-02-25 12:48:55.880199  | dqn_pong_0 Frame-based buffer using 4-frame sequences.
saving buffer ...
done with saving buffer. now let us load ...
buffer loaded

The buffer is here and is 16G in size (it's pre-allocated):

(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ ls -lh buffer.pkl 
-rw-rw-r-- 1 seita seita 16G Feb 25 12:49 buffer.pkl
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ 

However, I cannot reload the buffer from a separate Python process. For example:

(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ ipython
Python 3.7.5 (default, Oct 25 2019, 15:51:11) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.10.2 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import pickle                                                                                                              

In [2]: with open('buffer.pkl', 'rb') as fh: 
   ...:     buffer = pickle.load(fh) 
   ...:                                                                                                                            
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-2-97de235d15c6> in <module>
      1 with open('buffer.pkl', 'rb') as fh:
----> 2     buffer = pickle.load(fh)
      3 

UnpicklingError: NEWOBJ class argument isn't a type object

In [3]: exit                                                                                                                       
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ 

I believe this has to do with the general difficulties of saving entire class instances through pickle files. (The docs for torch save/load, which use pickle, suggest something similar.)
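For context, pickle stores an instance's class by reference (module plus name), not by value, so the loading process has to resolve that name to the same class. The error above is consistent with the name resolving to something that is not a class in the fresh process. A minimal sketch with a standard-library namedtuple; the `Transition` type is hypothetical and not taken from rlpyt:

```python
import pickle
from collections import namedtuple

# Hypothetical record type, defined at module level so pickle can find it by name.
Transition = namedtuple("Transition", ["observation", "action", "reward"])

original = Transition(observation=[0.0, 1.0], action=2, reward=1.0)
payload = pickle.dumps(original, protocol=4)

# The payload carries only the class *name*; the definition is not serialized.
print(b"Transition" in payload)

# Loading succeeds here only because `Transition` is still defined in this process.
restored = pickle.loads(payload)
print(restored)
```

Loading the same bytes in a process where `Transition` is missing, or is bound to a different kind of object, would fail at unpickling time.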

As of now it seems like the way I should save and load replay buffers, at a high level, would be:

I will need to double check if there are variables we need to track that might depend on time (e.g., prioritization terms). If you have any advice that would be great. Thanks!

astooke commented 4 years ago

Ok really good question and I haven't tried this before.

One possible difficulty is the use of namedarraytuple, which relies on class definitions that can cause pickling problems, as recently discussed in #99.

A solution was suggested in that issue and I just posted a first attempt at building that out...check out the namedtuple_schema branch. It has classes (https://github.com/astooke/rlpyt/blob/namedtuple_schema/rlpyt/utils/namedtuple_schema.py) for making objects that behave like namedarraytuples without requiring the class definition. Hopefully not too painful to grab that and replace every X=namedtuple(...) and Y=namedarraytuple(..) definition with X=NamedTupleSchema(...) and Y=NamedArrayTupleSchema(...). Then retry the pickle test? Hopefully this would be the most universal solution.
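To illustrate the general idea (this is not the code in the namedtuple_schema branch), here is a minimal sketch in which one ordinary class describes any record type, so a pickle only ever references two fixed module-level classes. The names `RecordSchema` and `Record` are hypothetical:

```python
import pickle


class RecordSchema:
    """Hypothetical stand-in: one ordinary class describes any record type."""

    def __init__(self, typename, fields):
        self.typename = typename
        self.fields = tuple(fields)

    def __call__(self, *values):
        if len(values) != len(self.fields):
            raise TypeError(f"{self.typename} expects {len(self.fields)} values")
        return Record(self, values)


class Record:
    """Instance that resolves field names through its schema, not its class."""

    def __init__(self, schema, values):
        self._schema = schema
        self._values = tuple(values)

    def __getattr__(self, name):
        # Called only when normal lookup fails. Private and dunder names must
        # raise right away, otherwise pickle's probing during load would recurse.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return self._values[self._schema.fields.index(name)]
        except ValueError:
            raise AttributeError(name) from None


Samples = RecordSchema("Samples", ["observation", "reward"])
batch = Samples([1, 2, 3], 0.5)

# The schema travels inside the pickle as plain instance data.
restored = pickle.loads(pickle.dumps(batch, protocol=4))
print(restored.observation, restored.reward)
```

The trade-off in this sketch is that field access goes through `__getattr__`, which is slower than the generated properties of a real namedtuple class.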

I did work previously to make sure namedarraytuple can be pickled/unpickled, for example by always defining them at the module level and never nested inside another class...but I may have fallen short.

EDIT: oh you might also need to go into buffer_from_example() and prevent it from using namedarraytuple_like(), but have it do the NamedArrayTuple equivalent (which I haven't written yet)

astooke commented 4 years ago

OK, with today's pushes to master, NamedArrayTuple is in place there. It might be worth giving this data structure a try to see if it resolves the pickling issue. If so, you can drop it in for namedarraytuple everywhere, I think by using the new kwarg in buffer_from_example(..., use_NatSchema=True) wherever your samples or replay buffers are allocated.

jordan-schneider commented 4 years ago

I'm going to try this, as I need the functionality as well. Having to save an entire 16 GB pre-allocated buffer also seems like a mistake, so I'm also going to spend some time thinking about how to save only the non-empty elements in a non-full buffer.
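A rough sketch of that idea, assuming a simple ring buffer backed by NumPy arrays. `ToyReplayBuffer`, `save_buffer`, and `load_buffer` are hypothetical names, not rlpyt's API, and a real buffer may carry more state (for example priorities or frame history) that this ignores:

```python
import os
import tempfile

import numpy as np


class ToyReplayBuffer:
    """Hypothetical pre-allocated ring buffer; not rlpyt's buffer class."""

    def __init__(self, capacity, obs_dim):
        self.observation = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.reward = np.zeros(capacity, dtype=np.float32)
        self.cursor = 0  # next write index
        self.filled = 0  # number of valid entries

    def append(self, obs, reward):
        self.observation[self.cursor] = obs
        self.reward[self.cursor] = reward
        capacity = len(self.reward)
        self.cursor = (self.cursor + 1) % capacity
        self.filled = min(self.filled + 1, capacity)


def save_buffer(buf, path):
    # Write only the valid rows plus the bookkeeping needed to resume.
    n = buf.filled
    np.savez(path, observation=buf.observation[:n], reward=buf.reward[:n],
             cursor=buf.cursor, capacity=len(buf.reward))


def load_buffer(path):
    # Re-allocate at full capacity, then copy the saved rows back in.
    with np.load(path) as data:
        obs = data["observation"]
        buf = ToyReplayBuffer(int(data["capacity"]), obs.shape[1])
        n = obs.shape[0]
        buf.observation[:n] = obs
        buf.reward[:n] = data["reward"]
        buf.cursor = int(data["cursor"])
        buf.filled = n
    return buf


buf = ToyReplayBuffer(capacity=1000, obs_dim=4)
for t in range(10):
    buf.append(np.full(4, t, dtype=np.float32), float(t))

path = os.path.join(tempfile.mkdtemp(), "buffer.npz")
save_buffer(buf, path)
restored = load_buffer(path)
print(restored.filled, restored.cursor)
```

Because only plain arrays and integers are written, the file can be loaded in a fresh process without importing any buffer class definitions.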

jordan-schneider commented 4 years ago

@astooke you have a bug in NamedTupleSchema: if you pass a delimited string as the fields parameter, it tries to use the separator as a field. Do you want the fix for this as a separate PR, or should I submit it alongside the replace-all PR?

jordan-schneider commented 4 years ago

E.g. in class BatchSpec(NamedTupleSchema("BatchSpec", "T B")): NamedTupleSchema tries to parse the middle space as a separate field.
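A minimal sketch of one possible fix, using the same normalization that collections.namedtuple applies to string field specs. `parse_fields` is a hypothetical helper, not a function in rlpyt:

```python
def parse_fields(fields):
    """Return field names as a tuple, accepting a delimited string or an iterable."""
    if isinstance(fields, str):
        # Treat commas as whitespace, then split on any run of whitespace.
        fields = fields.replace(",", " ").split()
    return tuple(fields)


print(parse_fields("T B"))         # ('T', 'B')
print(parse_fields("T, B"))        # ('T', 'B')
print(parse_fields(["T", "B"]))    # ('T', 'B')

# Iterating the raw string instead yields the separator as a field.
print(tuple("T B"))                # ('T', ' ', 'B')
```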

jordan-schneider commented 4 years ago

Actually the example pickle code seems to have worked on main. There are still namedtuples and namedarraytuples running around the code base. Did you just replace the ones necessary to pickle buffers?

DanielTakeshi commented 4 years ago

@jordan-schneider Sorry, I wasn't able to look at this again recently. Have you been able to make progress on it?

astooke commented 4 years ago

@jordan-schneider Yes, I just replaced the ones needed for pickling buffers. So the original definitions can stay as namedarraytuple, but the buffers are created as NamedArrayTuple anyway. The exceptions are the gym wrapper classes (space and env), which dynamically define named-array-tuple classes.

Thanks for the PR!

This all working now?

jordan-schneider commented 4 years ago

AFAICT, yes, but my case is slightly different from Daniel's.

astooke commented 4 years ago

@DanielTakeshi Curious if this ended up working for you, or if we still need to patch something?

DanielTakeshi commented 4 years ago

Sorry @astooke but I haven't had time to go back to my use case for the above [which is slightly more complex than what I actually reported in the issue] and I might not get around to it for a while.

DanielTakeshi commented 4 years ago

For now we can close, I'll re-open if needed later. Thanks for your help.