openai / baselines

OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
MIT License
15.65k stars 4.86k forks

HER MPI broadcasting issue with non-Reach environments #843

Open vitchyr opened 5 years ago

vitchyr commented 5 years ago

The following command runs fine:

time mpirun -np 8 python -m baselines.run --num_env=2 --alg=her --env=FetchReach-v1 --num_timesteps=100000 

However, if I change the environment to FetchPush-v1 or FetchPickAndPlace-v1, I get the following error when trying to run multiple MPI processes:

Training...
Traceback (most recent call last):
  File "/home/vitchyr/anaconda2/envs/baselines2/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/vitchyr/anaconda2/envs/baselines2/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/vitchyr/git/baselines/baselines/run.py", line 246, in <module>
    main(sys.argv)
  File "/home/vitchyr/git/baselines/baselines/run.py", line 210, in main
    model, env = train(args, extra_args)
  File "/home/vitchyr/git/baselines/baselines/run.py", line 79, in train
    **alg_kwargs
  File "/home/vitchyr/git/baselines/baselines/her/her.py", line 181, in learn
    policy_save_interval=policy_save_interval, demo_file=demo_file)
  File "/home/vitchyr/git/baselines/baselines/her/her.py", line 59, in train
    logger.record_tabular(key, mpi_average(val))
  File "/home/vitchyr/git/baselines/baselines/her/her.py", line 20, in mpi_average
    return mpi_moments(np.array(value))[0]
  File "/home/vitchyr/git/baselines/baselines/common/mpi_moments.py", line 22, in mpi_moments
    mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True)
  File "/home/vitchyr/git/baselines/baselines/common/mpi_moments.py", line 16, in mpi_mean
    comm.Allreduce(localsum, globalsum, op=MPI.SUM)
  File "mpi4py/MPI/Comm.pyx", line 714, in mpi4py.MPI.Comm.Allreduce

These different environments work for me if I run them without MPI.

I am using Anaconda. My Python version is 3.6.2, and this is the output of pip freeze:

absl-py==0.7.0
astor==0.7.1
baselines==0.1.5
certifi==2018.11.29
cffi==1.12.2
chardet==3.0.4
Click==7.0
cloudpickle==0.8.0
Cython==0.29.6
dill==0.2.9
future==0.17.1
gast==0.2.2
glfw==1.7.1
grpcio==1.19.0
gym==0.12.0
h5py==2.9.0
idna==2.8
imageio==2.5.0
joblib==0.13.2
Keras-Applications==1.0.7
Keras-Preprocessing==1.0.9
lockfile==0.12.2
Markdown==3.0.1
mock==2.0.0
mpi4py==3.0.1
mujoco-py==1.50.1.59
numpy==1.16.2
opencv-python==4.0.0.21
pbr==5.1.3
Pillow==5.4.1
progressbar2==3.39.2
protobuf==3.7.0
pycparser==2.19
pyglet==1.3.2
python-utils==2.3.0
requests==2.21.0
scipy==1.2.1
six==1.12.0
tensorboard==1.13.0
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
tqdm==4.31.1
urllib3==1.24.1
Werkzeug==0.14.1
gfsliumin commented 5 years ago

Hi, I have the same problem. Have you solved this?

keshaviyengar commented 5 years ago

I run into this error as well.

krishpop commented 5 years ago

Same here, but for the HandManipulatePen env, though it prints the first log table just fine.

keshaviyengar commented 5 years ago

I've temporarily fixed the issue in my custom environment by scaling the success rate to be out of 100 rather than 1 (i.e., multiplying the success rate by 100 in rollout.py). The issue lies somewhere in mpi_moments.py when calculating the mean squared difference with a non-zero success rate across multiple CPUs. Hope this is useful.
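For context, the computation in question is a two-pass moment calculation: the global mean is reduced first, then the squared differences from that mean. Roughly (a pure-NumPy sketch of the two passes, with the MPI Allreduce replaced by a plain sum over per-worker chunks; not the exact upstream code):

```python
import numpy as np

def moments_two_pass(chunks):
    # Emulate mpi_moments: each worker holds one chunk of values; the
    # "Allreduce" here is simply summing the per-worker partial results.
    sums = [np.sum(c) for c in chunks]
    counts = [c.size for c in chunks]
    mean = sum(sums) / sum(counts)
    # Second pass: sum of squared differences from the global mean.
    sqdiffs = [np.sum((c - mean) ** 2) for c in chunks]
    std = np.sqrt(sum(sqdiffs) / sum(counts))
    return mean, std

# Two "workers", each holding two success-rate samples.
mean, std = moments_two_pass([np.array([0.0, 1.0]), np.array([1.0, 0.0])])
```

With real MPI, each of those sums is an `Allreduce`, which is where the thread above sees the failure.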

wecacuee commented 5 years ago

It seems like an issue related to x.dtype. Probably the workers are for some reason generating a different x.dtype, and MPI is not able to reconcile them. A workaround is to force the dtype of localsum on line 12 to be np.float64. The right fix would be to figure out why the dtypes differ.
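The workaround above amounts to casting the local buffer to np.float64 before it reaches Allreduce, so every rank contributes a buffer of identical dtype. A sketch of the idea (this mirrors the shape of mpi_mean's sum-plus-count buffer but is not the exact upstream code, and it runs without MPI):

```python
import numpy as np

def local_sum_and_count(x, axis=0):
    # Force a consistent dtype so every MPI worker would contribute a
    # float64 buffer to Allreduce, regardless of its local array's dtype.
    x = np.asarray(x, dtype=np.float64)
    xsum = x.sum(axis=axis, keepdims=True)
    # Pack the sum and the element count into one buffer, as mpi_mean does.
    localsum = np.zeros(xsum.size + 1, dtype=np.float64)
    localsum[:xsum.size] = xsum.ravel()
    localsum[-1] = x.shape[axis]
    return localsum

# Two "workers" with different input dtypes now produce buffers of
# identical dtype and layout, so an Allreduce could combine them.
a = local_sum_and_count(np.array([1, 2, 3], dtype=np.float32))
b = local_sum_and_count(np.array([4.0, 5.0, 6.0]))  # float64 input
assert a.dtype == b.dtype == np.float64
```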

jangirrishabh commented 5 years ago

The above workaround works for me, thanks!

wxplovehlt commented 5 years ago

Does `time mpirun -np 8` mean num_cpu=8? And if I want to use MPI, must the `mpirun` command be added?

keshaviyengar commented 5 years ago

> Does `time mpirun -np 8` mean num_cpu=8? And if I want to use MPI, must the `mpirun` command be added?

Yes, that's how it works. Check out the MPI documentation on mpirun, but you do need to add it to use MPI.