Closed: Sean-Hastings closed this issue 4 years ago.
Update: while trying to work around this, I hit another case of what seems to be the same issue. I think it is something to do with the RLBench gym env, because that is the main connection I can see between these examples. When I run the code below with
python RLBench/tools/dataset_generator.py --tasks empty_dishwasher --processes 4
it gets stuck as described in the main post.
Note: the following is a modified version of the provided dataset_generator.py
from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from rlbench import ObservationConfig
from rlbench.action_modes import ActionMode
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task
import gym
import rlbench.gym
import os
import pickle
from PIL import Image
from rlbench.backend import utils
from rlbench.backend.const import *
import numpy as np
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('save_path',
                    '/tmp/rlbench_data/',
                    'Where to save the demos.')
flags.DEFINE_list('tasks', [],
                  'The tasks to collect. If empty, all tasks are collected.')
flags.DEFINE_integer('processes', 1,
                     'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
                     'The number of episodes to collect per task.')


def check_and_make(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def run(i, results, env):
    """Each thread will choose one task and variation, and then gather
    all the episodes_per_task for that variation."""
    # Initialise each thread with a different seed.
    tasks_with_problems = ''
    np.random.seed(i)

    print('Process', i, 'started collecting for task', env.task.get_name())
    task_path = os.path.join(FLAGS.save_path, env.task.get_name())
    check_and_make(task_path)

    for ex_idx in range(FLAGS.episodes_per_task):
        abort_variation = False
        attempts = 10
        while attempts > 0:
            try:
                # TODO: for now we do the explicit looping.
                demo, = env.task.get_demos(
                    amount=1,
                    live_demos=True)
                # Debug marker: report that get_demos returned successfully.
                problem = (
                    'Process %d failed collecting task %s (example: %d).'
                    ' Skipping this task.\n%s\n' % (
                        i, env.task.get_name(), ex_idx, 'got to making demos')
                )
                print(problem)
                tasks_with_problems += problem
                abort_variation = True
                break
            except Exception as e:
                attempts -= 1
                if attempts > 0:
                    continue
                problem = (
                    'Process %d failed collecting task %s (example: %d).'
                    ' Skipping this task.\n%s\n' % (
                        i, env.task.get_name(), ex_idx, str(e))
                )
                print(problem)
                tasks_with_problems += problem
                abort_variation = True
                break

        episode_path = os.path.join(task_path, str(i + FLAGS.processes * ex_idx))
        check_and_make(episode_path)  # ensure the episode directory exists
        for j, obs in enumerate(demo):
            with open(os.path.join(episode_path, str(j)), 'wb') as f:
                pickle.dump(env._extract_obs(obs), f)

        if abort_variation:
            break

    results[i] = tasks_with_problems


def main(argv):
    task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
                  if t != '__init__.py' and t.endswith('.py')]
    if len(FLAGS.tasks) > 0:
        for t in FLAGS.tasks:
            if t not in task_files:
                raise ValueError('Task %s not recognised!' % t)
        task_files = FLAGS.tasks

    # The env is created once, in the parent process.
    env = [gym.make(t + '-state-v0') for t in task_files][0]

    manager = Manager()
    result_dict = manager.dict()
    check_and_make(FLAGS.save_path)

    processes = [Process(
        target=run, args=(
            i, result_dict, env))
        for i in range(FLAGS.processes)]
    [t.start() for t in processes]
    [t.join() for t in processes]

    print('Data collection done!')
    for i in range(FLAGS.processes):
        print(result_dict[i])


if __name__ == '__main__':
    app.run(main)
Hi. I know the issue:
You are calling gym.make (which in turn launches RLBench) outside of the process. This is not allowed.
You'll notice that in the original dataset_generator.py, launch is called inside each of the processes.
All the best, Stephen
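For reference, a minimal sketch of that pattern, assuming the gym-registered RLBench tasks: gym.make is called inside each worker, so RLBench is launched once per process rather than in the parent. The task name, process count, and the run body are illustrative, not the actual dataset_generator.py logic.

from multiprocessing import Process, Manager

import gym
import rlbench.gym  # registers the RLBench gym environments


def run(i, results, task_name):
    # gym.make (and therefore the RLBench launch) happens inside the process.
    env = gym.make(task_name + '-state-v0')
    try:
        env.reset()
        # ... collect demos / rollouts for this worker here ...
        results[i] = 'ok'
    finally:
        env.close()


if __name__ == '__main__':
    manager = Manager()
    results = manager.dict()
    processes = [Process(target=run, args=(i, results, 'empty_dishwasher'))
                 for i in range(4)]
    [p.start() for p in processes]
    [p.join() for p in processes]
    print(dict(results))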
I am attempting to use RLBench in conjunction with PyTorch-RL (https://github.com/Khrylx/PyTorch-RL), and it works for single-process but not multi-process data collection.
The relevant file that I am using is as follows:
The important bit is that when num_threads is 1, and therefore only the main-thread call to collect_samples(...) occurs, everything runs fine. If num_threads > 1, however, it gets stuck waiting for the worker threads to put their work into the queue.
As an experiment, I ran the same thing with
raise Exception(pid)
inserted at various points in collect_samples and found that the worker threads were entering
state = env.reset()
and then never exiting or raising any exceptions. The system monitor also shows that the main thread uses the CPU noticeably while working through the data collection, while the worker threads never go above 0% CPU usage.
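For what it's worth, here is a stripped-down sketch of the pattern that seems to trigger the hang. This is not the actual PyTorch-RL collect_samples code; the worker function, queue, and task name are illustrative. The env is created in the parent process, and the workers then block inside env.reset() and never put anything on the queue, so the main process waits forever.

from multiprocessing import Process, Queue

import gym
import rlbench.gym  # registers the RLBench gym environments


def worker(pid, queue, env):
    # The env (and its PyRep-backed simulator) was created in the parent
    # process; in my runs this reset() call never returns in the child.
    state = env.reset()
    queue.put((pid, state))


if __name__ == '__main__':
    env = gym.make('empty_dishwasher-state-v0')  # created in the parent
    queue = Queue()
    workers = [Process(target=worker, args=(pid, queue, env))
               for pid in range(2)]
    [w.start() for w in workers]
    results = [queue.get() for _ in workers]  # blocks forever here
    [w.join() for w in workers]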