Closed: wangwwno1 closed this issue 5 years ago.
Hi! The library does not currently support multi-agent environment interactions directly, although I hope it is the sort of thing it could be extended to. One way to do it would be to write the multiple agents into one agent, and then use a Composite action space to collect all of their actions and pass them into the environment. The algorithm would then have access to the multiple agents as well. It could be a fairly quick thing, or there might be some hidden difficulties... let us know if you try and what happens! Happy to answer more questions or help along the way.
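For example, the packed action space could look roughly like this (a sketch based on my reading of rlpyt/spaces/composite.py and float_box.py; double-check the exact signatures):

from rlpyt.spaces.composite import Composite
from rlpyt.spaces.float_box import FloatBox
from rlpyt.utils.collections import namedarraytuple

# Must be defined at module level so it pickles across sampler processes.
AgentAction = namedarraytuple("AgentAction", ["predator", "prey"])

predator_space = FloatBox(low=-1, high=1, shape=(2,))
prey_space = FloatBox(low=-1, high=1, shape=(3,))  # different shapes are fine
action_space = Composite((predator_space, prey_space), AgentAction)

a = action_space.sample()  # AgentAction(predator=array(...), prey=array(...))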
Thanks for your reply! Currently my model works in the way you described (not with this awesome lib yet), but with limited single-GPU memory, I can't use a larger batch size to accelerate training. So I hope there is a way to parallelize batch processing.
BTW, I guess Composite supports heterogeneous (differently shaped) action spaces? That would be great, because then I can pack actions from different agents (e.g. Predator and Prey) into one object.
Cool! Yes, Composite should be able to support any arbitrary structure of other sub-spaces, including other composites.
Two ways you could go to multi-GPU on this... one would be to put the different agent models on different GPUs, so the agent would hold multiple devices. The other way would be like what's already in here: make each model data parallel and use the SyncRl runner.
: ) Hello astooke! I'm trying to transfer the model/algo part to this lib, and there are several questions about the implementation:
Hi! Ok let's see...
- The samples.done signal.
- get_itr_snapshot() of the runner class does this by calling the logger's save_itr_params() (in the train script, use logger_context(snapshot_mode="last") to save the most recent). Currently, this doesn't save the replay buffer or the agent's recurrent state, so you would have to add that. The replay buffer is a bunch of numpy arrays.

Hope that helps!
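A rough sketch of adding the replay buffer to the snapshot (the base fields mirror what I recall of the MinibatchRl runners' get_itr_snapshot; verify against your rlpyt version):

def get_itr_snapshot(self, itr):
    return dict(
        itr=itr,
        cum_steps=itr * self.sampler.batch_size * self.world_size,
        agent_state_dict=self.agent.state_dict(),
        optimizer_state_dict=self.algo.optim_state_dict(),
        replay_buffer=getattr(self.algo, "replay_buffer", None),  # added: numpy arrays, can be large
    )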
Previous pickle problem solved by replacing the namedtuple with a dataclass.
To test my environment (with a slight mod in the obs to fit the single-agent model), I use the original SAC algo and SacAgent model, and they work smoothly with SerialSampler. However, a new error is raised with GpuSampler. It looks like the key "info" is missing from globals(), but I'm not sure where to begin; any suggestions? Here is my code:
def build_and_train(env_id="Hopper-v3", run_ID=0, cuda_idx=None, n_parallel=2):
    config = dict(
        env=dict(id=env_id),
        algo=dict(batch_size=128),
        sampler=dict(batch_T=2, batch_B=32),
    )
    sampler = GpuSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        CollectorCls=GpuWaitResetCollector,
        eval_env_kwargs=dict(id=env_id),
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
        # batch_T=4,  # Get from config.
        # batch_B=1,
        **config["sampler"]  # More parallel environments for batched forward-pass.
    )
    algo = SAC()
    agent = SacAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))),
    )
    name = "dqn_" + env_id
    log_dir = "example_5"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env_id', help='Atari id', default='RobertWang_Env:FairMEC-v0')
    parser.add_argument('--run_ID', help='run identifier (logging)', type=int, default=0)
    parser.add_argument('--cuda_idx', help='gpu to use ', type=int, default=None)
    parser.add_argument('--n_parallel', help='number of sampler workers', type=int, default=4)
    args = parser.parse_args()
    build_and_train(
        env_id=args.env_id,
        run_ID=args.run_ID,
        cuda_idx=args.cuda_idx,
        n_parallel=args.n_parallel,
    )
And Error Report:
Process Process-2:
Traceback (most recent call last):
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\process.py", line 297, in _bootstrap
self.run()
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\process.py", line 99, in run
self._target(*self._args, **self._kwargs)
File "D:\Data\GitHub\rlpyt\rlpyt\samplers\buffer.py", line 68, in get_example_outputs
o, r, d, env_info = env.step(a)
File "D:\Data\GitHub\rlpyt\rlpyt\envs\gym.py", line 51, in step
info = info_to_nt(info)
File "D:\Data\GitHub\rlpyt\rlpyt\envs\gym.py", line 85, in info_to_nt
ntc = globals()[name]
KeyError: 'info'
Traceback (most recent call last):
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1758, in <module>
main()
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1752, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1147, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "D:/Data_Compressed/PyCharm-Workspace/RobertWang's Wheel/Test/example_5.py", line 67, in <module>
n_parallel=args.n_parallel,
File "D:/Data_Compressed/PyCharm-Workspace/RobertWang's Wheel/Test/example_5.py", line 52, in build_and_train
runner.train()
File "D:\Data\GitHub\rlpyt\rlpyt\runners\minibatch_rl.py", line 229, in train
n_itr = self.startup()
File "D:\Data\GitHub\rlpyt\rlpyt\runners\minibatch_rl.py", line 61, in startup
world_size=world_size,
File "D:\Data\GitHub\rlpyt\rlpyt\samplers\parallel\base.py", line 53, in initialize
examples = self._build_buffers(env, bootstrap_value)
File "D:\Data\GitHub\rlpyt\rlpyt\samplers\parallel\gpu\sampler.py", line 54, in _build_buffers
examples = super()._build_buffers(*args, **kwargs)
File "D:\Data\GitHub\rlpyt\rlpyt\samplers\parallel\base.py", line 145, in _build_buffers
agent_shared=True, env_shared=True, subprocess=True)
File "D:\Data\GitHub\rlpyt\rlpyt\samplers\buffer.py", line 29, in build_samples_buffer
all_action = buffer_from_example(examples["action"], (T + 1, B), agent_shared)
File "<string>", line 2, in __getitem__
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\managers.py", line 811, in _callmethod
raise convert_to_error(kind, result)
KeyError: 'action'
-------------------Previous Reply-----------------
Just completed my custom gym env for testing. Everything went well until I used GpuSampler to parallelize it - it raised a PicklingError:
_pickle.PicklingError: Can't pickle <class 'RobertWang_Env.Env.UAVCommEnv.FairMEC.general_space'>: attribute lookup general_space on RobertWang_Env.Env.UAVCommEnv.FairMEC failed
Is there any way to work around it? The "general_space" thing is a namedtuple defined as: GENERAL_SPEC = namedtuple('general_space', ['size', 'movable', 'collide', 'max_speed', 'acceleration', 'color'])
A (possible?) minor enhancement: Can we introduce a single torch.nn.ModuleDict() to preserve and organize all models in the Agent? That could make loading the state_dict / setting the mode on all models a bit easier.
It would be something like:
def __init__(self, *args, **kwargs):
    # Do some stuff...
    # Initialize a ModuleDict for the models; the same idea can also apply to optimizers or memory.
    # Create model dicts.
    self._model_callbacks = nn.ModuleDict()
    self._optimizer_callbacks = OrderedDict()
    # Create memory dicts.
    self._memory_callbacks = OrderedDict()
    # Do other stuff...

def initialize(self, env_spaces, share_memory=False, global_B=1, env_ranks=None):
    # Do something...
    self.actor = self.model_Actor(**actor_model_kwargs)
    self.actor_target = deepcopy(self.actor)
    self.critic = self.model_Critic(**critic_model_kwargs)
    self.critic_target = deepcopy(self.critic)
    self._model_callbacks.update(OrderedDict(
        actor=self.actor,
        actor_target=self.actor_target,
        critic=self.critic,
        critic_target=self.critic_target,
    ))
Take the DDPG structure as an example. Before introducing the ModuleDict:
def eval_mode(self, itr):
    """Go into evaluation mode."""
    self.actor.eval()
    self.critic.eval()
    self.actor_target.eval()
    self.critic_target.eval()
    self._mode = "eval"
    self.distribution.set_std(0.)  # Deterministic.

def state_dict(self):
    """Parameters for saving."""
    return dict(
        actor=self.actor.state_dict(),
        critic=self.critic.state_dict(),
        actor_target=self.actor_target.state_dict(),
        critic_target=self.critic_target.state_dict(),
    )

def load_state_dict(self, state_dict):
    self.actor.load_state_dict(state_dict['actor'])
    self.critic.load_state_dict(state_dict['actor'])  # Oops - easy to make this mistake.
    self.actor_target.load_state_dict(state_dict['actor_target'])
    self.critic_target.load_state_dict(state_dict['critic_target'])
With self._model_callbacks:
def eval_mode(self, itr):
    """Go into evaluation mode."""
    self._model_callbacks.eval()
    self._mode = "eval"
    self.distribution.set_std(0.)  # Deterministic.

def state_dict(self):
    return dict([(key, model.state_dict()) for key, model in self._model_callbacks.items()])

def load_state_dict(self, state_dict):
    for key, model in self._model_callbacks.items():
        model.load_state_dict(state_dict[key])
Or more directly:
def state_dict(self):
    return self._model_callbacks.state_dict()

def load_state_dict(self, state_dict):
    self._model_callbacks.load_state_dict(state_dict)
Still, I'm not sure whether it is compatible with the parallel mechanism, but I will give it a try : )
It looks like the EvalCollector does not accept a reward with ndim > 1... It's easy to return the sum or mean of a reward batch, but in a distributed multi-agent scenario, the model might need to accept multiple rewards to calculate a q_value for each agent (especially in a partially observed scenario where an agent can only "see" part of the world).
Note: If the env returns a sequence of rewards, an error will be raised:
File "d:\data\github\rlpyt\rlpyt\runners\minibatch_rl.py", line 231, in train
eval_traj_infos, eval_time = self.evaluate_agent(0)
File "d:\data\github\rlpyt\rlpyt\runners\minibatch_rl.py", line 251, in evaluate_agent
traj_infos = self.sampler.evaluate_agent(itr)
File "d:\data\github\rlpyt\rlpyt\samplers\serial\sampler.py", line 87, in evaluate_agent
return self.eval_collector.collect_evaluation(itr)
File "d:\data\github\rlpyt\rlpyt\samplers\serial\collectors.py", line 57, in collect_evaluation
reward[b] = r
ValueError: setting an array element with a sequence.
Hi! In response to the pickling and 'info' question....this is an awkward part of using namedtuples. They must be defined at the module level (e.g. in a file, outside of any function or class) in order to pickle/unpickle correctly.
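For example (an illustrative toy, not rlpyt code):

from collections import namedtuple
import pickle

# Module-level definition, with the variable name matching the first argument: picklable.
EnvInfo = namedtuple("EnvInfo", ["timeout"])
# Defining it inside a class or function, or with a mismatched name, breaks pickling.

print(pickle.loads(pickle.dumps(EnvInfo(timeout=False))))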
Answer 1: Maybe you can get away without pickling. In the GpuSampler, wherever build_samples_buffer is called, you can try it with the kwarg subprocess=False. This should avoid any pickling. The possible downside is that the NN module will be called with a forward pass before all processes are forked, and this initializes OpenMP/MKL threading, which can be problematic (but try it; for the GpuSampler it should actually be fine, I think).
Answer 2: In envs/gym.py, when the first instance of the env is created, build_info_tuples(info) inside __init__() creates the env_info namedtuple outside the class. Further instances should then recognize that and use it. When it comes time to step the env, it looks for this namedtuple class at the module level, in info_to_nt. Maybe something went wrong in build_info_tuples, so that info_to_nt is not able to find the namedtuple class it was supposed to create? Does that help?
nn.ModuleDict looks handy, I hadn't seen that before! If it works with DistributedDataParallel wrappers, then I don't see a reason not to use it.
Hmm, yes, unfortunately reward is hard-coded as a scalar in a few places, like in the DecorrelatingStartCollector and the SerialEvalCollector. Maybe if you change those few places it could work? With luck, algorithm functions like discount_return() can work with extra dimensions in the reward, as long as time remains the leading dim (but you should check this). An alternative would be to put the multiple rewards into the env_info, and be sure to store that in the replay buffer.
Hello astooke, I tried the suggestion in Answer 1, and the error message is a bit different:
Traceback (most recent call last):
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1758, in <module>
main()
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1752, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\pydevd.py", line 1147, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\Code_Studios\PyCharm 2019.1.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "D:/Data_Compressed/PyCharm-Workspace/RobertWang's Wheel/Test/example_2.py", line 74, in <module>
n_parallel=args.n_parallel,
File "D:/Data_Compressed/PyCharm-Workspace/RobertWang's Wheel/Test/example_2.py", line 59, in build_and_train
runner.train()
File "d:\data\github\rlpyt\rlpyt\runners\minibatch_rl.py", line 229, in train
n_itr = self.startup()
File "d:\data\github\rlpyt\rlpyt\runners\minibatch_rl.py", line 61, in startup
world_size=world_size,
File "d:\data\github\rlpyt\rlpyt\samplers\parallel\base.py", line 71, in initialize
w.start()
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\process.py", line 112, in start
self._popen = self._Popen(self)
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\context.py", line 322, in _Popen
return Popen(process_obj)
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
reduction.dump(process_obj, to_child)
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'rlpyt.utils.collections.info'>: attribute lookup info on rlpyt.utils.collections failed
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "D:\Toolkits\Anaconda3\envs\py37-th110-gym\lib\multiprocessing\spawn.py", line 115, in _main
self = reduction.pickle.load(from_parent)
EOFError: Ran out of input
A similar error happens with the original Multi-Agent Env + my customized model. Most of the traceback is exactly the same (with minor differences in the file path of my test.py), so I only post the difference. When subprocess=True:
AttributeError: Can't get attribute 'obs_others_mask' on <module 'rlpyt.spaces.gym_wrapper' from 'd:\\data\\github\\rlpyt\\rlpyt\\spaces\\gym_wrapper.py'>
When subprocess=False:
_pickle.PicklingError: Can't pickle <class 'rlpyt.utils.collections.obs'>: attribute lookup obs on rlpyt.utils.collections failed
Hmm I haven't seen this behavior before when forking with namedtuples. Is there a way you can share your multi-agent env with me so I can try running some things? Or better, whatever is the simplest code which reproduces the errors?
Hmm, maybe the start method matters. Currently I code and test on Windows 10, which by default uses spawn to start processes (not fork).
Here is my Custom Env with a Test.py for the Single-Agent case (with the default SAC from the lib): RobertWang-Env.zip
Before running the code, you might need to register the env with pip install -e . in the folder where setup.py is located.
I'm still working on the Multi-Agent part of the agent and algo, but you can test the Multi-Agent part of this environment by following these steps:
- Replace self.observation_space and self.action_space in _init_space(self) with the Multi-Agent version (see commented code).
- Replace observation in observation(self) with the Multi-Agent version (see commented code).
- Pass action as a 2D float ndarray. By default the action array should have shape=(10, 5) and range=[0., 1.].
nn.ModuleDict looks handy, I hadn't seen that before! If it works with DistributedDataParallel wrappers, then I don't see a reason not to use it.
Unfortunately, it (probably) can't.
After examining the source code of DistributedDataParallel, it looks like we cannot directly pass an nn.ModuleDict to DistributedDataParallel. Unlike cuda or to(device), DistributedDataParallel does not modify the model in-place; instead, it stores the nn.Module (in this case, the nn.ModuleDict) itself, returns a DDP object, and distributes data in its forward pass.
Since nn.ModuleDict stores references to the models, if we wrap a model with DistributedDataParallel, the nn.ModuleDict must be updated accordingly, like:
self.actor = DDP(self.actor)
self.critic = DDP(self.critic)
self._model_callbacks.update(OrderedDict(
    actor=self.actor,
    critic=self.critic,
))
Note: According to this doc, it's recommended to update a ModuleDict with an ordered mapping like OrderedDict.
The "general_space" thing is a namedtuple defined as: GENERAL_SPEC = namedtuple('general_space', ['size', 'movable', 'collide', 'max_speed', 'acceleration', 'color'])
Actually, the reason that can't pickle is because the variable name in the file and the first arg to namedtuple must be the same for pickling to work, like so: GeneralSpace = namedtuple('GeneralSpace', ...). Sorry I didn't read that more closely before! Give this a try?
As for the multi-agent observation, you might need to make a custom rlpyt GymEnvWrapper, and possibly GymSpaceWrapper, because you might need to make the global_observation_space also into an rlpyt-style space, the same way the existing wrappers are used on the regular observation_space. Reading your file, this is my first guess; let me know if this helps?
Currently I code and test in Windows 10, by default they will use spawn to start process (not fork).
I haven't tested anything in Windows, so it's possible something else will come up.
Actually, the reason that can't pickle is because the variable name in the file and the first arg to namedtuple must be the same for pickling to work, like so: GeneralSpace = namedtuple('GeneralSpace', ...). Sorry I didn't read that more closely before! Give this a try?
👍 Thanks for your persistent tracking of this issue, astooke! Sorry I didn't make it clearer - the pickling error (about general_space) is already solved.
Renaming the variable cannot solve this problem, however. So I replaced the namedtuple with a dataclass. It's a class specifically designed to store object attributes.
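A minimal sketch of the dataclass replacement (the field types are my guesses; the field names follow the namedtuple mentioned earlier):

from dataclasses import dataclass
from typing import Tuple

@dataclass
class GeneralSpec:
    size: float
    movable: bool
    collide: bool
    max_speed: float
    acceleration: float
    color: Tuple[float, float, float]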
As for the multi-agent observation, you might need to make a custom rlpyt GymEnvWrapper, and possibly GymSpaceWrapper, because you might need to make the global_observation_space also into an rlpyt-style space, the same way the existing wrappers are used on the regular observation_space. Reading your file, this is my first guess, let me know if this helps?
Currently there is no need to customize a GymEnvWrapper.
The observation returned by the environment is defined in observation_space, NOT global_observation_space (similar to samples_from_replay and ReplayBuffer.samples: samples_from_replay is a reorganized portion of ReplayBuffer.samples).
Still, good suggestion for future expansion! : )
I haven't tested anything in Windows, so it's possible something else will come up.
That's OK. Whether these errors happen or not, my next move is to "transfer my code to a Ubuntu server" (laugh). I believe these errors will disappear after changing the OS. I'm still working on it; if it actually works, I will inform you ASAP : )
About the whole AttributeError / PicklingError problem, here is my thought:
Windows doesn't support the fork start method (the default on Unix); it uses spawn as the start method instead, and according to this doc, they work differently:
spawn
- The child process will only inherit those resources necessary to run the process object's run() method.
- Unnecessary file descriptors and handles from the parent process will NOT be inherited.

fork
- The child process, when it begins, is effectively identical to the parent process.
- All resources of the parent are inherited by the child process.
namedtuple is a factory method (the class is defined on the fly). So it's possible that the spawn method does not pass those namedtuples, defined in the parent process, to the child.
Remember the PicklingError msg:
_pickle.PicklingError: Can't pickle <class 'RobertWang_Env.Env.UAVCommEnv.FairMEC.general_space'>: attribute lookup general_space on RobertWang_Env.Env.UAVCommEnv.FairMEC failed
By definition, general_space is defined outside of the FairMEC (the environment) class. Without that information, the "spawned" child process mistakes the namedtuple for an attribute of the class. A "forked" child process, however, receives the defined namedtuple from the parent process, thus avoiding this problem.
In the past I've gotten the same error if the assigned object name and the first argument into namedtuple don't match. For example, abc = namedtuple("xyz", "one, two") will cause the pickling error, but abc = namedtuple("abc", "one, two") works. Did you try this?
I already tried this in my environment:
GeneralSpec = namedtuple('GeneralSpec', ["spec1", "spec2"])
I remember it would raise a PicklingError on Windows 10, but I haven't tried it on Linux, nor in WSL (Windows Subsystem for Linux), so I'm not sure what will happen there. Maybe it is a system compatibility issue? Like my earlier post guessed, caused by the different process start methods?
Ok interesting to know about fork vs spawn.
Hopefully the move to Linux will make things work, good luck and let us know!
Hello astooke, after transferring to Ubuntu and using the fork start method, the PicklingError disappeared. It appears again when setting 'spawn' as the start method (Unix supports both spawn and fork), so I think they are closely related.
However, parallelism with the GPU does not work. When moving models to the cuda device, there is an error message without a traceback:
THCudaCheck FAIL file=/tmp/pip-req-build-58y_cjjl/aten/src/THC/THCGeneral.cpp line=54 error=3 : initialization error
After that, the process will "freeze". Is there anything wrong in my script?
Edit:
Another possibility is in this PyTorch doc, which says the CUDA runtime does not support fork as the process start method.
But I still want to achieve data parallelism across multiple GPUs - one model copied to the other GPUs and everything in one process. Any good suggestions?
Another Edit:
torch.nn.DataParallel is great for a single process with multiple GPUs. That's good enough for my training!
Here is the script:
def build_and_train(env_id, run_ID, cuda_idx=None):
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=12,
        n_gpu=2,
        gpu_per_run=2,
    )
    sampler = CpuSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        eval_env_kwargs=dict(id=env_id),
        CollectorCls=CpuWaitResetCollector,
        batch_T=100,
        batch_B=12,
        max_decorrelation_steps=0,
        eval_n_envs=12,
        eval_max_steps=int(16 * 1e2),
        eval_max_trajectories=5,
    )
    algo = MyModel(
        discount=0.99,
        batch_size=128,  # 32 * 200 (~15900 MB), 96 * 100 (14220 MB)
        min_steps_learn=int(6e2),
        replay_size=int(1e3),
        replay_ratio=128,  # data_consumption / data_generation
        learning_rate=1e-4,
        q_learning_rate=1e-3,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        initial_optim_state_dict=None,
        clip_grad_norm=1e8,
        q_target_clip=1e6,
        n_step_return=1,
        pre_train=True,
        scene_memory_size=int(1e2),
    )
    agent = MyAgent()
    runner = SyncRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e4,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(env_id=env_id)
    name = "sac_" + env_id
    log_dir = "RWExp"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
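For reference, the DataParallel wrapping itself (not shown in the script above) follows roughly this generic pattern; this is an illustrative sketch, not my actual model:

import torch
import torch.nn as nn

model = nn.Linear(64, 8)  # stand-in for the real model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # splits each batch across all visible GPUs
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

x = torch.randn(256, 64, device=device)
y = model(x)  # the forward pass is scattered/gathered across GPUs automatically
print(y.shape)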
Hello again! I solved the training problem by using torch.nn.DataParallel. : )
Now I'm trying to work out the logger - it gives me error logs in rlpyt/data/local, but I can't find my trained model parameters and episode rewards.
Is there any shortcut for saving model parameters & recording rewards?
Ah, sorry, I found it in logger_context, now I can start my training! : )
A tip: it would be better if the logger could store different files (e.g. debug logs, model params) in different folders, and allow the user to specify the saving path outside of this great lib!
Another tip: when a user runs the same experiment file twice, the latter experiment will overwrite the former record, but sometimes that's not the desired behavior (the user wants two records, but forgot to change log_path). We could use an incremental suffix to solve this problem: if a certain suffix already exists, the logger changes it to find an unused path, thus storing experiments (same file, run at different times, maybe with different params) in different paths.
Here is a prototype from my project, hope that helps! : )
import os

def get_output_folder(parent_dir, env_name):
    """Return save folder.

    Assumes folders in the parent_dir have suffix -run{run
    number}. Finds the highest run number and sets the output folder
    to that number + 1. This is just convenient so that if you run the
    same script multiple times tensorboard can plot all of the results
    on the same plots with different names.

    Parameters
    ----------
    parent_dir: str
        Path of the directory containing all experiment runs.

    Returns
    -------
    parent_dir/run_dir
        Path to this run's save directory.
    """
    os.makedirs(parent_dir, exist_ok=True)
    experiment_id = 0
    for folder_name in os.listdir(parent_dir):
        if not os.path.isdir(os.path.join(parent_dir, folder_name)):
            continue
        try:
            folder_name = int(folder_name.split('-run')[-1])
            if folder_name > experiment_id:
                experiment_id = folder_name
        except ValueError:
            pass
    experiment_id += 1
    parent_dir = os.path.join(parent_dir, env_name)
    parent_dir = parent_dir + '-run{}'.format(experiment_id)
    os.makedirs(parent_dir, exist_ok=True)
    return parent_dir
Hello astooke! Here is an enhancement proposal! During the warm-up process (filling the replay memory without actual training), the evaluation can cost 1/3 of the time and spam a lot of un-trained model.pkl files. Using self.algo.min_iter_learn to control it would save a lot of time (and reduce hair loss : ) ). It would be even better if we could add a separate progress bar (or log) to monitor the replay collection progress.
Hi! OK a few points to respond to...
As for torch.nn.DataParallel, I'm not sure about the error that's coming back from the built-in DistributedDataParallel setup. Using the SyncRl runner should fork multiple processes before initializing any CUDA, so that each process sets up its own CUDA. Does your agent have more different models which need to go to the device? Can you show us that code? You should try to follow the pattern in the existing agents for doing so. Cool that DataParallel is working for you! You probably don't need the SyncRl runner for that, but can just use the regular MinibatchRl, which keeps to just one python process.
Also, something else I just noticed: you probably want log_interval_steps to be at least as large as the sampler batch size (batch_B * batch_T), since you can only log once per iteration anyway (I usually log only once per several iterations). It also might not be desirable to have a replay_size smaller than the batch size, or a min_steps_learn smaller than the batch size.
As for overwriting the log files....yes in more recent work I've had this problem and solved it by extending the date/time-stamp from only YYYYMMDD to YYYYMMDD-HHMMSS (down to the second). So each time you launch goes into a different folder. But incrementing a run counter is nice because that's a more meaningful label!
To write to a different log directory, you can just change this line: https://github.com/astooke/rlpyt/blob/8331b7f919bcecd3dfe9ece85d2fb479d471ff6f/rlpyt/utils/logging/context.py#L10 or otherwise modify the logger context as you wish.
For the last point of running evaluations before the agent even starts learning... yes, I've usually avoided this by setting log_interval_steps greater than or equal to min_steps_learn. Otherwise you could modify the logger to check for this property when deciding whether to log, with getattr(self.algo, "min_steps_learn", 0) or something like that.
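For instance, a hypothetical runner-side guard (should_log is an illustrative name, and it assumes the runner knows its steps-per-iteration, called itr_batch_size here):

def should_log(self, itr):
    # Skip logging/evaluation until the algo has collected min_steps_learn steps.
    min_itr_learn = getattr(self.algo, "min_steps_learn", 0) // max(1, self.itr_batch_size)
    return itr >= min_itr_learn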
Hope this helps!
As for torch.nn.DataParallel, I'm not sure the error that's coming back from the built-in DistributedDataParallel setup. Using the SyncRl runner should fork multiple processes before initializing any CUDA, so that each process sets up its own CUDA. Does your agent have more different models which need to go to the device? Can show us that code? Should try to follow the pattern in the existing agents for doing so.
: ) Yes, my agent has multiple models, but I'm not sure whether they are correctly passed to shared_models. Currently I use nn.ModuleDict to organize these models:
self._model_callbacks = torch.nn.ModuleDict(dict(
    spatial_encoder=S_Encoder(**kwargs),
    temporal_encoder=T_Encoder(**kwargs),
    actor=Actor(**kwargs),
    critic=Critic(**kwargs),
    # And their target models, which are deepcopies of the four aforementioned models : )
))
self._shared_model_callbacks = torch.nn.ModuleDict()  # Empty at start; will be defined later in initialize.
So there are 4 * 2 = 8 models. To simplify references, I define multiple properties in the agent so I can directly use self.actor to refer to self._model_callbacks['actor']:
@property
def actor(self):
    return self._model_callbacks['actor']
# This applies to ALL models in self._model_callbacks.
Apart from these encoder models, the agent is almost identical to DDPGAgent, with self.actor replacing self.model and self.critic replacing self.q_model, and the same for their target networks.
The major difference is that in DDPGAgent both actor and critic receive the observation directly; in my agent, however, the observation is encoded first, then passed to actor and critic.
So, compared with vanilla DDPG, my "DDPGAgent" is something like:
self.model = nn.Sequential(
    self.spatial_encoder,
    self.temporal_encoder,
    self.actor,
)
self.q_model = nn.Sequential(
    self.spatial_encoder,   # Same model as the one in self.model
    self.temporal_encoder,  # Same model as the one in self.model
    self.critic,
)
# A similar definition applies to their target counterparts.
So far everything is fine, but when it comes to multiprocessing, things become tricky.
Since the sampler and evaluator use MyAgent.step to generate an action from an observation, I put these three models (self.spatial_encoder, self.temporal_encoder, and self.actor) in self._shared_model_callbacks:
# self._shared_model_keys = ('spatial_encoder', 'temporal_encoder', 'actor')
if share_memory:
    update_list = []
    for key in self._shared_model_keys:
        self._model_callbacks[key].share_memory()
        update_list.append((key, self._model_callbacks[key]))
    self._shared_model_callbacks.update(update_list)
    del update_list
I didn't put self.critic in self._shared_model_callbacks - self.critic is used to determine the update strength of self.actor, and is irrelevant to MyAgent.step.
When there is a need to move the models:
def to_device(self, cuda_idx=None):
    # TODO Test
    """Overwrite/extend for formats other than 'self.model' for network(s)."""
    if cuda_idx is None:
        return
    if len(self._shared_model_callbacks) > 0:
        # I tried two ways to take care of these models.
        # The first way is to re-initialize specific models with stored **kwargs
        # (self.spatial_encoder, self.temporal_encoder, and self.actor are involved):
        # update_keys = self._shared_model_callbacks.keys()
        # self._initialize_models(update_keys=update_keys)
        # The other way is to directly copy the whole self._model_callbacks.
        self._model_callbacks = deepcopy(self._model_callbacks)
        # Update the state_dict()s from the shared models.
        for key, shared_model in self._shared_model_callbacks.items():
            model = self._model_callbacks[key]
            update_state_dict(model, shared_model.state_dict(), strip_ddp=True)
    self.device = torch.device("cuda", index=cuda_idx)
    self._model_callbacks.to(self.device)  # Initialization Error
    logger.log(f"Initialized agent model on device: {self.device}.")

def data_parallel(self):
    # TODO Test
    def update_model_callbacks(para_method):
        for key, model in self._model_callbacks.items():
            if not key.endswith("_target"):
                # Target networks won't require grads, so don't parallelize them.
                self._model_callbacks[key] = para_method(model)
    if self.device.type == "cpu":
        update_model_callbacks(DDPC)
        logger.log("Initialized DistributedDataParallelCPU agent model.")
    else:
        update_model_callbacks(lambda x: DDP(x, device_ids=[self.device.index], output_device=self.device.index))
        logger.log(f"Initialized DistributedDataParallel agent model on device {self.device}.")

def sync_shared_memory(self):
    """Call in sampler master (non-async), after initialize(share_memory=True)."""
    for key, shared_model in self._shared_model_callbacks.items():
        model = self._model_callbacks[key]
        if shared_model is not model:
            # (shared_model gets trained)
            update_state_dict(shared_model, model.state_dict(), strip_ddp=True)
However, this line always raises an initialization error:
self._model_callbacks.to(self.device)
It gets logged into the debug file, but won't terminate the main process.
I think nn.ModuleDict is just a specialized OrderedDict; it should be OK to take a model reference from it. So perhaps there is something wrong in the definition of self._shared_model_callbacks.
Cool that DataParallel is working for you! You probably don't need the SyncRl runner for that, but can just use the regular MinibatchRl, which keeps to just one python process.
nn.DataParallel + MinibatchRl = Humming GPUs + Happy Researcher 🎉
Still, I would like to provide assistance to resolve this issue.
Let me know if further information is needed.
For the last point of running evaluations before the agent even starts learning...yes I've usually avoided this by setting log_interval_steps greater than or equal to min_steps_learn. Otherwise you could modify the logger to check for this property when deciding whether to log, getattr(self.algo, "min_steps_learn", 0) or something like that.
Setting log_interval_steps greater than or equal to min_steps_learn is not a perfect solution, especially with a large experience replay. (E.g. in my scenario it takes 1000 itr to collect 1000 episodes of experience, each itr = 100 steps, and it also takes another 1000 itr to train the model. If I applied this setting, I wouldn't be able to monitor the training progress, because it would only generate ONE log at the end of the training.)
I made a small patch in the MinibatchRl runner, since it's the runner that controls when and how to run evaluation and record debug logs. There are other runners with their own logging functions, but all of them inherit from the BaseRunner class, so extracting the logging functions to BaseRunner and patching it once for all would be a better solution.
: )Hi, astooke, mind me ask some dumb question about the clip_grad_norm_
?
I have noticed that there are clip_gradnorm in the algo
part, as it do gradient normalization and record values like `q_grad_norm' in the progress.csv. It's the first time I use grad_norm in my model, so I have two questions about them:
grad_norm
in progress.csv, and how to use them to evaluate the training progress.self.clip_grad_norm
in algo
?: ) Hello astooke, the training is finished and is a great💯 success.
Many thanks to you and this awesome lib! 🎉
Since the original problem is resolved, I will close this issue and make a summary in the first post about how to handle the output of a Multi-Agent environment. Once the paper is finished, I would like to contribute a citation to the whitepaper of this great lib.
Have a good day, and Happy Reinforcement Learning!
Hello, does the lib support multi-agent environments? Or more precisely, does it allow multiple agents to share the environment state, select their actions in parallel, then return the combined actions to the environment?

-----------------Edit-----------------

After multiple tries, I figured out some tips for training with a Multi-Agent environment.

How to pass multi-agent observations
If there is only one model for all agents, simply pack all observations into one array and pretend it is a single mega-agent environment. If there are multiple models, follow the same procedure, and also devise the algo and agent parts of your code accordingly. It's recommended to use torch.nn.ModuleList or torch.nn.ModuleDict to organize multiple models, then apply functions in parallel to each model.

How to pass multiple reward values
A typical Gym environment step return should be a four-element tuple: observation, reward, done, info. The reward in the return of step must be a scalar, because evaluation needs it to calculate the total episode reward. However, sometimes you may want to have a unique reward for each agent, which must be a 1d array. The key point of the solution is to pass your actual reward through an output other than reward. To resolve this problem, modify the environment along the lines of the sketch below:
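A minimal sketch, assuming the per-agent rewards are packed into info under a key named actual_reward (the wrapper class name is illustrative, not the original code):

import gym
import numpy as np

class MultiRewardEnv(gym.Wrapper):
    def step(self, action):
        obs, reward_vector, done, info = self.env.step(action)  # reward_vector: shape (n_agents,)
        info["actual_reward"] = np.asarray(reward_vector, dtype=np.float32)  # keep the full array
        return obs, float(np.sum(reward_vector)), done, info  # scalar reward keeps evaluation happy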
Then, in the algo part of your code, modify initialize_replay_buffer, samples_to_buffer, and any functions relevant to the conversion between samples and buffer:
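A minimal sketch of the samples_to_buffer side, assuming a SAC-style algo and an env_info field named actual_reward (the field layout mirrors rlpyt's SamplesToBuffer; double-check against your version):

from rlpyt.utils.collections import namedarraytuple

SamplesToBuffer = namedarraytuple("SamplesToBuffer",
    ["observation", "action", "reward", "done"])

def samples_to_buffer(self, samples):
    return SamplesToBuffer(
        observation=samples.env.observation,
        action=samples.agent.action,
        reward=samples.env.env_info.actual_reward,  # per-agent array instead of the scalar reward
        done=samples.env.done,
    )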
After that, your algo will receive actual_reward (which is an array or nested array) instead of the scalar reward.