pfnet / pfrl

PFRL: a PyTorch-based deep reinforcement learning library
MIT License

Fix an error related to collation and RNN #129

Closed tarokiritani closed 3 years ago

tarokiritani commented 3 years ago

Recurrent policies expect data collated in a tuple.
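
For context, here is a minimal sketch of the collation behaviour this refers to. It is not the PFRL implementation: `collate_keep_tuple` is a hypothetical helper, and it assumes PyTorch's `default_collate` hands back a list when each sample is a plain tuple, which is what the versions I have checked do.

```python
import torch
from torch.utils.data.dataloader import default_collate


def collate_keep_tuple(features):
    """Hypothetical helper: collate samples and keep a tuple structure."""
    collated = default_collate(features)
    # Assumption: for plain-tuple samples, default_collate returns a list
    # of batched fields, so convert it back to a tuple.
    if isinstance(features[0], tuple):
        collated = tuple(collated)
    return collated


# Two observations, each a tuple of (2x2 array, scalar id).
obs = [
    (torch.zeros(2, 2), torch.tensor(0)),
    (torch.ones(2, 2), torch.tensor(1)),
]
batch = collate_keep_tuple(obs)
```

With this, `batch` is a tuple whose first element has shape `(2, 2, 2)` and whose second element holds the two ids.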

muupan commented 3 years ago

/test

pfn-ci-bot commented 3 years ago

Successfully created a job for commit 671b6d6:

muupan commented 3 years ago

It seems the test fails on GPU. Can you apply this patch so that it works on both CPU and GPU? Everything else looks good.

diff --git a/tests/utils_tests/test_batch_states.py b/tests/utils_tests/test_batch_states.py
index ba4391b3..cdfedcd8 100644
--- a/tests/utils_tests/test_batch_states.py
+++ b/tests/utils_tests/test_batch_states.py
@@ -27,7 +27,7 @@ class TestBatchStates(unittest.TestCase):
         self.assertIsInstance(batch, tuple)
         batch_a, batch_b, batch_c = batch
         np.testing.assert_allclose(
-            batch_a,
+            batch_a.cpu(),
             np.asarray(
                 [
                     [[0, 2], [4, 6]],
@@ -35,9 +35,9 @@ class TestBatchStates(unittest.TestCase):
                 ]
             ),
         )
-        np.testing.assert_allclose(batch_b, np.asarray([0, 1]))
+        np.testing.assert_allclose(batch_b.cpu(), np.asarray([0, 1]))
         np.testing.assert_allclose(
-            batch_c,
+            batch_c.cpu(),
             np.asarray(
                 [
                     [0],
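
As a rough standalone illustration of why the patch adds `.cpu()`: NumPy cannot read a tensor that lives on a CUDA device, so the tensor has to be moved to host memory before it is compared. The device selection and the explicit `.numpy()` call below are my own additions for the sketch, not part of the test file.

```python
import numpy as np
import torch

# Use a GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_b = torch.tensor([0, 1], device=device)

# .cpu() is a no-op for CPU tensors and a device-to-host copy for CUDA
# tensors, so the same assertion runs unchanged on both.
np.testing.assert_allclose(batch_b.cpu().numpy(), np.asarray([0, 1]))
```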

muupan commented 3 years ago

/test

pfn-ci-bot commented 3 years ago

Successfully created a job for commit b5f6907: