astooke / rlpyt

Reinforcement Learning in PyTorch
MIT License

No module named 'rlpyt.samplers.cpu' #15

Closed: tzahishimkin closed this 4 years ago

tzahishimkin commented 5 years ago

Running 'mujoco_ff_a2c_cpu.py' and getting this error. I can't find a 'cpu' directory in 'samplers', only inside the subdirectory 'parallel'.

tzahishimkin commented 5 years ago

I might be missing something, but I am unable to locate the following modules:

from rlpyt.samplers.gpu.parallel_sampler import GpuParallelSampler
from rlpyt.samplers.gpu.collectors import WaitResetCollector
from rlpyt.runners.minibatch_rl_eval import MinibatchRlEval
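
From poking around the current repo layout, it looks like these may have been moved or renamed; my best guess (unverified) at the updated imports is:

# Unverified guesses at the new locations:
from rlpyt.samplers.parallel.gpu.sampler import GpuSampler                # was gpu.parallel_sampler.GpuParallelSampler
from rlpyt.samplers.parallel.gpu.collectors import GpuWaitResetCollector  # was gpu.collectors.WaitResetCollector
from rlpyt.runners.minibatch_rl import MinibatchRlEval                    # was runners.minibatch_rl_eval.MinibatchRlEval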

mlpanda commented 5 years ago

Just adding to this comment: I tried running the r2d1 experiment for Atari and constantly ran into functions which did not exist. Initially I debugged and found that it was mainly due to changes in function names; however, at some stage I ran into larger changes, hence this comment. @astooke, I would very much appreciate it if you could look through the code and update the experiments to match the newer changes elsewhere in the repo - and perhaps add an r2d1 example file :) Thanks in advance.

astooke commented 5 years ago

Hi! Yes indeed, the experiments folder has a bunch of leftover scripts written over the course of development, as things were changing. That should definitely be cleaned up, to leave only the currently running examples. Will do soon! Sorry about the confusion.

mlpanda commented 5 years ago

Thanks @astooke, much appreciated! :-)

tzahishimkin commented 5 years ago

Thanks a lot!

So which example code do you use to run the DDPG model?

astooke commented 5 years ago

Hi! Sorry, I lost this for a while... but a current working example should be:

https://github.com/astooke/rlpyt/blob/master/rlpyt/experiments/scripts/mujoco/qpg/launch/got/launch_mujoco_ddpg_serial.py

which calls

https://github.com/astooke/rlpyt/blob/master/rlpyt/experiments/scripts/mujoco/qpg/train/mujoco_ddpg_serial.py

DanielTakeshi commented 4 years ago

(I think the launch script you're referring to shouldn't be in the got subdirectory -- incidentally what does that name represent?)

@astooke I have several quick related questions.

I am trying to understand the experiments directory. These scripts are serial, judging by the file names, and I see the serial sampler here:

https://github.com/astooke/rlpyt/blob/75e96cda433626868fd2a30058be67b99bbad810/rlpyt/experiments/scripts/mujoco/qpg/train/mujoco_ddpg_serial.py#L24-L30

However, the launch file has this:

https://github.com/astooke/rlpyt/blob/75e96cda433626868fd2a30058be67b99bbad810/rlpyt/experiments/scripts/mujoco/qpg/launch/launch_mujoco_ddpg_serial.py#L7-L13

I'm a bit confused: serial implies one python process running everything, but the launch file suggests that many CPUs and GPUs can be used, which in turn suggests multiple processes. What is the proper way to interpret n_cpu_core, n_gpu, and cpu_per_run in this context?
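
My tentative reading, which may well be wrong, is that the launch-file resources describe how many runs can be stacked on one machine, while each run itself stays serial; roughly:

from rlpyt.utils.launching.affinity import encode_affinity

# My tentative (unverified) reading of these arguments:
affinity_code = encode_affinity(
    n_cpu_core=8,          # total physical cores the launcher may hand out across runs
    n_gpu=1,               # total GPUs the launcher may hand out across runs
    hyperthread_offset=8,  # logical-core index where hyperthreads begin (often = physical core count)
    n_socket=1,            # number of CPU sockets on the machine
    cpu_per_run=2,         # cores reserved for each individual experiment run
)
# With 8 cores and 2 cores per run, the launcher could keep up to 8 // 2 = 4
# experiments running at once, each in its own python process, while each run
# still uses the single-process serial sampler internally.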

DanielTakeshi commented 4 years ago

Just a quick note that none of the four current policy gradient scripts for MuJoCo appear to be up to date with the latest changes on master today:

[Screenshot from 2019-11-21 09-34-50]

DanielTakeshi commented 4 years ago

I finally managed to figure out how to parallelize DDPG for MuJoCo, since only the serial scripts were available. Here is a working launch file for the CPU:

from rlpyt.utils.launching.affinity import encode_affinity
from rlpyt.utils.launching.exp_launcher import run_experiments
from rlpyt.utils.launching.variant import make_variants, VariantLevel

script = "rlpyt/experiments/scripts/mujoco/qpg/train/mujoco_ddpg_cpu.py"
affinity_code = encode_affinity(
    n_cpu_core=8,
    n_gpu=1,
    hyperthread_offset=2,
    n_socket=1,
    cpu_per_run=2,
)
runs_per_setting = 1
default_config_key = "ddpg_from_td3_1M_parallel_cpu"
experiment_title = "mujoco_ddpg_parallel_cpu"
variant_levels = list()

env_ids = ["Hopper-v2"]  # , "Swimmer-v3"]
values = list(zip(env_ids))
dir_names = ["env_{}".format(*v) for v in values]
keys = [("env", "id")]
variant_levels.append(VariantLevel(keys, values, dir_names))

variants, log_dirs = make_variants(*variant_levels)

run_experiments(
    script=script,
    affinity_code=affinity_code,
    experiment_title=experiment_title,
    runs_per_setting=runs_per_setting,
    variants=variants,
    log_dirs=log_dirs,
    common_args=(default_config_key,),
)

and the corresponding training script:

"""
Run DDPG using sampling on the CPU, with parallelism.
"""
import sys
from rlpyt.utils.launching.affinity import affinity_from_code
#from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler  # outdated
#from rlpyt.samplers.cpu.collectors import ResetCollector  # outdated
from rlpyt.samplers.parallel.cpu.sampler import CpuSampler  # correct replacement
from rlpyt.samplers.parallel.cpu.collectors import CpuWaitResetCollector # correct replacement
from rlpyt.envs.gym import make as gym_make
from rlpyt.algos.qpg.ddpg import DDPG
from rlpyt.agents.qpg.ddpg_agent import DdpgAgent
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.utils.logging.context import logger_context
from rlpyt.utils.launching.variant import load_variant, update_config

#from rlpyt.experiments.configs.mujoco.qpg.mujoco_a2c import configs  # outdated
from rlpyt.experiments.configs.mujoco.qpg.mujoco_ddpg import configs

def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    # Outdated
    #sampler = CpuParallelSampler(
    #    EnvCls=gym_make,
    #    env_kwargs=config["env"],
    #    CollectorCls=ResetCollector,
    #    **config["sampler"]
    #)
    # Replacement.
    sampler = CpuSampler(
        EnvCls=gym_make,
        env_kwargs=config["env"],
        eval_env_kwargs=config["env"],
        CollectorCls=CpuWaitResetCollector,
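        # batch_T: time steps collected per environment each sampler iteration;
        # batch_B: number of parallel environment instances, split among the CPU workers.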
        batch_T=5,
        batch_B=16,
        **config["sampler"]
    )

    algo = DDPG(optim_kwargs=config["optim"], **config["algo"])
    agent = DdpgAgent(**config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()

if __name__ == "__main__":
    build_and_train(*sys.argv[1:])

Note that these require changes to the config file to create some new config dictionaries:

https://github.com/astooke/rlpyt/blob/master/rlpyt/experiments/configs/mujoco/qpg/mujoco_ddpg.py

# Daniel: just add a bunch more keys.
config = copy.deepcopy(config)
config['sampler'].pop('batch_B')
config['sampler'].pop('batch_T')
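# (Popping batch_T / batch_B avoids passing them twice, since the train
# script above sets them explicitly when constructing the CpuSampler.)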
#config['sampler']['batch_T'] = 1   # Actually let's just do this in the launch file.
#config['sampler']['batch_B'] = 10  # Same comment as above
configs['ddpg_from_td3_1M_parallel_cpu'] = config
config = copy.deepcopy(config)
configs['ddpg_from_td3_1M_parallel_gpu'] = config

A similar pattern can be done for the GPU.
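
For the GPU, only the sampler imports and the sampler construction should need to change; an untested sketch of those parts of the train script (using the 'ddpg_from_td3_1M_parallel_gpu' config key added above):

from rlpyt.samplers.parallel.gpu.sampler import GpuSampler
from rlpyt.samplers.parallel.gpu.collectors import GpuWaitResetCollector

    # Untested: same constructor arguments as the CpuSampler above.
    sampler = GpuSampler(
        EnvCls=gym_make,
        env_kwargs=config["env"],
        eval_env_kwargs=config["env"],
        CollectorCls=GpuWaitResetCollector,
        batch_T=5,
        batch_B=16,
        **config["sampler"]
    )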

DanielTakeshi commented 4 years ago

@astooke I see how the experiments directory is generally structured. Do you have any suggestions or hacks you developed along the way to improve the workflow? Mainly I am wondering because it is a lot of typing and hitting TAB to get to the files deep in the repository (the launch and training files).

astooke commented 4 years ago

Yes, the general workflow includes three files:
1) a "run" script which builds multiple hyperparameter settings and launches a separate experiment for each one (this is the one script to call manually);
2) a "train" script which builds one individual experiment and trains the agent (the run script calls the train script in separate python process calls);
3) optionally, a "config" file which stores all the default settings for a given group of experiments, so that only the hyperparameters being swept need to be changed in the "run" script, since the "train" script has access to the config file.

I keep all of those under one sub-directory in the experiments folder. Then, after building up many of those while working along one line of research, I'll eventually start a new folder to work inside.
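
For the DDPG example discussed earlier in this thread, the three pieces live roughly like this (paths taken from the links above):

rlpyt/experiments/
    configs/mujoco/qpg/mujoco_ddpg.py                         # 3) config file with default settings
    scripts/mujoco/qpg/launch/launch_mujoco_ddpg_serial.py    # 1) "run"/launch script (the one to call)
    scripts/mujoco/qpg/train/mujoco_ddpg_serial.py            # 2) "train" script (called per experiment)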

Definitely not the only way to do it, but after doing another project or two using rlpyt, I've stuck with that. :)

astooke commented 4 years ago

OK, there are still a bunch of extra launch scripts floating around, but they basically just correspond to different computer configurations, and I think they all run now (or have their own issue), so I will leave it as is for the time being. :) Please reopen or make an issue if a specific one you want is broken.

DanielTakeshi commented 4 years ago

Minor note, the Categorical DQN scripts:

https://github.com/astooke/rlpyt/blob/master/rlpyt/experiments/scripts/atari/dqn/train/atari_catdqn_gpu.py
https://github.com/astooke/rlpyt/blob/master/rlpyt/experiments/scripts/atari/dqn/launch/launch_atari_catdqn_gpu_basic.py

still use the outdated paths. These can be replaced with this git diff:

-from rlpyt.samplers.gpu.sampler import GpuSampler
-from rlpyt.samplers.gpu.collectors import GpuWaitResetCollector
+from rlpyt.samplers.parallel.gpu.sampler import GpuSampler
+from rlpyt.samplers.parallel.gpu.collectors import GpuWaitResetCollector

astooke commented 4 years ago

@DanielTakeshi Thanks, just fixed that last one!