stepjam / RLBench

A large-scale benchmark and learning environment.
https://sites.google.com/corp/view/rlbench
Other
1.16k stars 235 forks source link

env.reset() getting stuck in worker processes #74

Closed Sean-Hastings closed 4 years ago

Sean-Hastings commented 4 years ago

I am attempting to use RLBench in conjunction with Pytorch-RL (https://github.com/Khrylx/PyTorch-RL) and it works for single- but not multi-processed data collection.

The relevant file that I am using is as follows:

import multiprocessing
from utils.replay_memory import Memory
from utils.torch import *
import math
import time

def collect_samples(pid, queue, env, policy, custom_reward,
                    mean_action, render, running_state, min_batch_size):
    try:
        torch.randn(pid)
        log = dict()
        memory = Memory()
        num_steps = 0
        total_reward = 0
        min_reward = 1e6
        max_reward = -1e6
        total_c_reward = 0
        min_c_reward = 1e6
        max_c_reward = -1e6
        num_episodes = 0

        while num_steps < min_batch_size:
            state = env.reset()
            if running_state is not None:
                state = running_state(state)
            reward_episode = 0

            for t in range(600):
                print(str(t) + '                    ', end='\r')
                state_var = tensor(state).unsqueeze(0)
                with torch.no_grad():
                    if mean_action:
                        action = policy(state_var)[0][0].numpy()
                    else:
                        action = policy.select_action(state_var)[0].numpy()
                action = int(action) if policy.is_disc_action else action.astype(np.float64)
                next_state, reward, done, _ = env.step(action)
                reward_episode += reward
                if running_state is not None:
                    next_state = running_state(next_state)

                if custom_reward is not None:
                    reward = custom_reward(state, action)
                    total_c_reward += reward
                    min_c_reward = min(min_c_reward, reward)
                    max_c_reward = max(max_c_reward, reward)

                mask = 0 if done else 1

                memory.push(state, action, mask, next_state, reward)

                if render:
                    env.render()
                if done:
                    break

                state = next_state

            # log stats
            num_steps += (t + 1)
            num_episodes += 1
            total_reward += reward_episode
            min_reward = min(min_reward, reward_episode)
            max_reward = max(max_reward, reward_episode)

        log['num_steps'] = num_steps
        log['num_episodes'] = num_episodes
        log['total_reward'] = total_reward
        log['avg_reward'] = total_reward / num_episodes
        log['max_reward'] = max_reward
        log['min_reward'] = min_reward
        if custom_reward is not None:
            log['total_c_reward'] = total_c_reward
            log['avg_c_reward'] = total_c_reward / num_steps
            log['max_c_reward'] = max_c_reward
            log['min_c_reward'] = min_c_reward

        if queue is not None:
            queue.put([pid, memory, log])
        else:
            return memory, log
    except Exception as e:
        if queue is not None:
            queue.put([pid, memory, log])
        else:
            raise e

def merge_log(log_list):
    log = dict()
    log['total_reward'] = sum([x['total_reward'] for x in log_list])
    log['num_episodes'] = sum([x['num_episodes'] for x in log_list])
    log['num_steps'] = sum([x['num_steps'] for x in log_list])
    log['avg_reward'] = log['total_reward'] / log['num_episodes']
    log['max_reward'] = max([x['max_reward'] for x in log_list])
    log['min_reward'] = min([x['min_reward'] for x in log_list])
    if 'total_c_reward' in log_list[0]:
        log['total_c_reward'] = sum([x['total_c_reward'] for x in log_list])
        log['avg_c_reward'] = log['total_c_reward'] / log['num_steps']
        log['max_c_reward'] = max([x['max_c_reward'] for x in log_list])
        log['min_c_reward'] = min([x['min_c_reward'] for x in log_list])

    return log

class Agent:

    def __init__(self, env, policy, device, custom_reward=None,
                 mean_action=False, render=False, running_state=None, num_threads=1):
        self.env = env
        self.policy = policy
        self.device = device
        self.custom_reward = custom_reward
        self.mean_action = mean_action
        self.running_state = running_state
        self.render = render
        self.num_threads = num_threads

    def collect_samples(self, min_batch_size):
        t_start = time.time()
        to_device(torch.device('cpu'), self.policy)
        thread_batch_size = int(math.floor(min_batch_size / self.num_threads))
        queue = multiprocessing.Queue()
        workers = []

        for i in range(1, self.num_threads):
            worker_args = (i, queue, self.env, self.policy, self.custom_reward, self.mean_action,
                           False, self.running_state, thread_batch_size)
            workers.append(multiprocessing.Process(target=collect_samples, args=worker_args))
        for worker in workers:
            worker.start()

        memory, log = collect_samples(0, None, self.env, self.policy, self.custom_reward, self.mean_action,
                                      self.render, self.running_state, thread_batch_size)

        worker_logs = [None] * len(workers)
        worker_memories = [None] * len(workers)
        print('compiling workers')
        for _ in workers:
            print('getting worker')
            e = queue.get()
            print('got worker')
            try:
                pid, worker_memory, worker_log = e
                worker_memories[pid - 1] = worker_memory
                worker_logs[pid - 1] = worker_log
            except Exception:
                print(e)
        print('appending memories')
        for worker_memory in worker_memories:
            memory.append(worker_memory)
        print('sampling memories')
        batch = memory.sample()
        print('merging logs')
        if self.num_threads > 1:
            log_list = [log] + worker_logs
            log = merge_log(log_list)
        to_device(self.device, self.policy)
        t_end = time.time()
        print('building stats')
        log['sample_time'] = t_end - t_start
        log['action_mean'] = np.mean(np.vstack(batch.action), axis=0)
        log['action_min'] = np.min(np.vstack(batch.action), axis=0)
        log['action_max'] = np.max(np.vstack(batch.action), axis=0)
        return batch, log

The important bit is that when num_threads is 1, and therefore only the main-thread call to collect_samples(...) occurs, everything runs great. If num_threads > 1, however, it get stuck waiting for the worker threads to put their work into the queue.

In an experiment I ran the same thing with raise Exception(pid) at various points in collect_samples and found that the worker threads were entering state = env.reset() and then never exiting or raising any exceptions. On the system monitor it also shows that the main thread uses the cpu noticeably while working through the data collection while the worker threads never pass 0% cpu usage.

Sean-Hastings commented 4 years ago

Update: trying to get around this I hit another case of what seems to be the same issue. I think it's something to do with the rlbench gym env because that's the main connection I can think of between these examples. When I run the below code with python RLBench/tools/dataset_generator.py --tasks empty_dishwasher --processes 4 it gets stuck as described in the main post.

Note: the following is a modified version of the provided dataset_generator.py

from multiprocessing import Process, Manager

from pyrep.const import RenderMode

from rlbench import ObservationConfig
from rlbench.action_modes import ActionMode
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task
import gym
import rlbench.gym

import os
import pickle
from PIL import Image
from rlbench.backend import utils
from rlbench.backend.const import *
import numpy as np

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('save_path',
                    '/tmp/rlbench_data/',
                    'Where to save the demos.')
flags.DEFINE_list('tasks', [],
                  'The tasks to collect. If empty, all tasks are collected.')
flags.DEFINE_integer('processes', 1,
                     'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
                     'The number of episodes to collect per task.')

def check_and_make(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def run(i, results, env):
    """Each thread will choose one task and variation, and then gather
    all the episodes_per_task for that variation."""

    # Initialise each thread with different seed
    tasks_with_problems = ''
    np.random.seed(i)

    print('Process', i, 'started collecting for task', env.task.get_name())

    task_path = os.path.join(FLAGS.save_path, env.task.get_name())
    check_and_make(task_path)

    for ex_idx in range(FLAGS.episodes_per_task):
        attempts = 10
        while attempts > 0:
            try:
                # TODO: for now we do the explicit looping.
                demo, = env.task.get_demos(
                    amount=1,
                    live_demos=True)
                problem = (
                    'Process %d failed collecting task %s (example: %d). Skipping this task.\n%s\n' % (
                        i, env.task.get_name(), ex_idx, 'got to making demos')
                )
                print(problem)
                tasks_with_problems += problem
                abort_variation = True
                break
            except Exception as e:
                attempts -= 1
                if attempts > 0:
                    continue
                problem = (
                    'Process %d failed collecting task %s (example: %d). Skipping this task.\n%s\n' % (
                        i, env.task.get_name(), ex_idx, str(e))
                )
                print(problem)
                tasks_with_problems += problem
                abort_variation = True
                break
            episode_path = os.path.join(task_path, str(i + FLAGS.processes*ex_idx))
            for j, obs in enumerate(demo):
                with open(os.path.join(episode_path, str(j)), 'wb') as f:
                    pickle.dump(env._extract_obs(obs), f)
        if abort_variation:
            break

    results[i] = tasks_with_problems

def main(argv):

    task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
                  if t != '__init__.py' and t.endswith('.py')]

    if len(FLAGS.tasks) > 0:
        for t in FLAGS.tasks:
            if t not in task_files:
                raise ValueError('Task %s not recognised!.' % t)
        task_files = FLAGS.tasks

    env = [gym.make(t + '-state-v0') for t in task_files][0]

    manager = Manager()

    result_dict = manager.dict()

    check_and_make(FLAGS.save_path)

    processes = [Process(
        target=run, args=(
            i, result_dict, env))
        for i in range(FLAGS.processes)]
    [t.start() for t in processes]
    [t.join() for t in processes]

    print('Data collection done!')
    for i in range(FLAGS.processes):
        print(result_dict[i])

if __name__ == '__main__':
  app.run(main)
stepjam commented 4 years ago

Hi. I know the issue: You are calling gym.make (which in-turn launches RLBench) outside of the process. This is not allowed. You'll notice that in the original dataset_generator.py, launch is called inside each of the processes.

All the best, Stephen