chainer / chainerrl

ChainerRL is a deep reinforcement learning library built on top of Chainer.
MIT License
1.18k stars · 224 forks · source link

TestTrainAgentAsync.test is flaky #578

Status: Open — opened by muupan 5 years ago

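muupan commented 5 years ago:

The `max_episode_len=None, num_envs=2` parameterization of `TestTrainAgentAsync.test` fails intermittently. With `num_envs > 1` the test asserts that `agent.act_and_train` is called strictly more than `steps` (50) times, but in this run it was called exactly 50 times: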
=================================== FAILURES ===================================
_____ TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2}.test ______
self = <chainer.testing._bundle.TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2} testMethod=test>
    def test(self):

        steps = 50

        outdir = tempfile.mkdtemp()

        agent = mock.Mock()
        agent.shared_attributes = []

        def _make_env(process_idx, test):
            env = mock.Mock()
            env.reset.side_effect = [('state', 0)] * 1000
            if self.max_episode_len is None:
                # Episodic env that terminates after 5 actions
                env.step.side_effect = [
                    (('state', 1), 0, False, {}),
                    (('state', 2), 0, False, {}),
                    (('state', 3), -0.5, False, {}),
                    (('state', 4), 0, False, {}),
                    (('state', 5), 1, True, {}),
                ] * 1000
            else:
                # Continuing env
                env.step.side_effect = [
                    (('state', 1), 0, False, {}),
                ] * 1000
            return env

        # Keep references to mock envs to check their states later
        envs = [_make_env(i, test=False) for i in range(self.num_envs)]
        eval_envs = [_make_env(i, test=True) for i in range(self.num_envs)]

        def make_env(process_idx, test):
            if test:
                return eval_envs[process_idx]
            else:
                return envs[process_idx]

        # Mock states cannot be shared among processes. To check states of mock
        # objects, threading is used instead of multiprocessing.
        # Because threading.Thread does not have .exitcode attribute, we
        # add the attribute manually to avoid an exception.
        import threading

        # Mock.call_args_list does not seem thread-safe
        hook_lock = threading.Lock()
        hook = mock.Mock()

        def hook_locked(*args, **kwargs):
            with hook_lock:
                return hook(*args, **kwargs)

        with mock.patch('multiprocessing.Process', threading.Thread),\
            mock.patch.object(
                threading.Thread, 'exitcode', create=True, new=0):
            chainerrl.experiments.train_agent_async(
                processes=self.num_envs,
                agent=agent,
                make_env=make_env,
                steps=steps,
                outdir=outdir,
                max_episode_len=self.max_episode_len,
                global_step_hooks=[hook_locked],
            )

        if self.num_envs == 1:
            self.assertEqual(agent.act_and_train.call_count, steps)
        elif self.num_envs > 1:
>           self.assertGreater(agent.act_and_train.call_count, steps)
E           AssertionError: 50 not greater than 50
tests/experiments_tests/test_train_agent_async.py:94: AssertionError