Closed by alexhernandezgarcia 7 months ago
OK I'm starting to look at this, in the dumbest way possible:

N = 2

def test__all_env_common(env):
    for _ in range(N):
        test__init__state_is_source_no_parents(env)
        ...
        test__gflownet_minimal_runs(env)
Any value of N > 1
raises an assertion error. I'm not sure if that's just because the env
state needs to be reset between iterations or not.
However a bigger problem arises:
$ pytest -x test_crystal.py
...
../../../../../miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10
/Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
from pkg_resources import resource_filename as rf
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:140: UserWarning: Skipping test because states are None.
warnings.warn("Skipping test because states are None.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:156: UserWarning: Skipping test because states are None.
warnings.warn("Skipping test because states are None.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:178: UserWarning: Skipping test because states are None.
warnings.warn("Skipping test because states are None.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:436: UserWarning: Skipping test for this specific environment.
warnings.warn("Skipping test for this specific environment.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:398: UserWarning: Skipping test for this specific environment.
warnings.warn("Skipping test for this specific environment.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/tests/gflownet/envs/common.py:205: UserWarning: Skipping test because states are None.
warnings.warn("Skipping test because states are None.")
test_crystal.py::test__all_env_common
/Users/jdv/code/gflownet/gflownet/utils/buffer.py:136: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
self.main = pd.concat(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================== short test summary info =======================================
FAILED test_crystal.py::test__all_env_common - AssertionError
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
============================== 1 failed, 28 passed, 8 warnings in 2.55s ==============================
(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet/envs:
The test output is very noisy, and when a test fails, it's not clear why. Maybe we should brainstorm options in our next standup :)
I think env.reset()
fixes the assertion error. So the question is now only about test legibility.
Yes, env.reset()
seems necessary. Thanks for looking into this. Let's definitely discuss this tomorrow! I hope to have a moment to check it beforehand.
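For reference, a minimal sketch of what resetting between repeats could look like (hypothetical helper; assumes env.reset() restores the source state, which the common tests expect at the start of each run):

```python
N = 2

def run_repeated(env, tests):
    """Run each common test N times, resetting the env between runs.

    `tests` is a list of callables taking the env. `env.reset()` is
    assumed to restore the source state so the next run starts clean.
    """
    for _ in range(N):
        for test in tests:
            test(env)
            env.reset()  # without this, leftover state trips the assertions
```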
I'm running the tests now. They occasionally fail, but it also takes a really long time; I'm not sure this is a practical solution (if we really need 500 tests).
OK, I had to kill the run. I think the runtime is more like 4 hours (not the 8 reported) due to my computer being in sleep mode for some of the day, but it's clear we need to reduce the runtime by ~100x.
Some of the tests with the repeats eventually fail, interestingly. I'm copying the output below so we can discuss:
(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet/envs:
pytest .
======================================== test session starts =========================================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/jdv/code/gflownet/tests/gflownet/envs
plugins: torchtyping-0.1.4, repeat-0.9.3, anyio-4.1.0, typeguard-4.1.5, hydra-core-1.3.2
collected 1611 items
test_ccrystal.py ............................................................................. [ 4%]
......................sss..................................................................... [ 10%]
........................................................F
. [ 14%]
test_ccube.py ................................................................................ [ 19%]
......................sssssssssss.. [ 21%]
test_clattice_parameters.py .................................................................. [ 25%]
.............................................................................................. [ 31%]
.............................................................................................. [ 37%]
.............................................................................................. [ 42%]
.............................................................................................. [ 48%]
.............................................................................................. [ 54%]
.............................................................................................. [ 60%]
........................................................................................ssssss [ 66%]
ssssssss.....F
. [ 67%]
test_composition.py .......................................................................... [ 71%]
......................................................^C
================================================ FAILURES =================================================
_______________________________________ test__continuous_env_common _______________________________________
env_sg_first = <gflownet.envs.crystals.ccrystal.CCrystal object at 0x178e9af20>
def test__continuous_env_common(env_sg_first):
print("\n\nCommon tests for crystal with space group first\n")
> return common.test__continuous_env_common(env_sg_first)
test_ccrystal.py:1606:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
common.py:48: in test__continuous_env_common
test__backward_actions_have_nonzero_forward_prob(env)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
env = <gflownet.envs.crystals.ccrystal.CCrystal object at 0x178e9af20>, n = 1000
def test__backward_actions_have_nonzero_forward_prob(env, n=1000):
# Skip for certain environments until fixed:
skip_envs = ["Crystal", "LatticeParameters"]
if env.__class__.__name__ in skip_envs:
warnings.warn("Skipping test for this specific environment.")
return
states = _get_terminating_states(env, n)
if states is None:
warnings.warn("Skipping test because states are None.")
return
policy_random = torch.unsqueeze(env.random_policy_output, 0)
for state in states:
env.set_state(state, done=True)
while True:
if env.equal(env.state, env.source):
break
state_next, action, valid = env.step_random(backward=True)
assert valid
# Get forward logprobs
mask_fw = env.get_mask_invalid_actions_forward()
masks = torch.unsqueeze(tbool(mask_fw, device=env.device), 0)
actions_torch = torch.unsqueeze(
tfloat(action, float_type=env.float, device=env.device), 0
)
policy_outputs = policy_random.clone().detach()
logprobs_fw = env.get_logprobs(
policy_outputs=policy_outputs,
actions=actions_torch,
mask=masks,
states_from=[env.state],
is_backward=False,
)
> assert torch.isfinite(logprobs_fw)
E AssertionError
common.py:467: AssertionError
------------------------------------------ Captured stdout call -------------------------------------------
Common tests for crystal with space group first
_________________________________ test__continuous_env_common[tetragonal] _________________________________
env = <gflownet.envs.crystals.clattice_parameters.CLatticeParameters object at 0x179049420>
lattice_system = 'tetragonal'
@pytest.mark.parametrize(
"lattice_system",
[CUBIC, HEXAGONAL, MONOCLINIC, ORTHORHOMBIC, RHOMBOHEDRAL, TETRAGONAL, TRICLINIC],
)
def test__continuous_env_common(env, lattice_system):
> return common.test__continuous_env_common(env)
test_clattice_parameters.py:304:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
common.py:48: in test__continuous_env_common
test__backward_actions_have_nonzero_forward_prob(env)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
env = <gflownet.envs.crystals.clattice_parameters.CLatticeParameters object at 0x179049420>, n = 1000
)
assert actions_trajectory_fw == actions_trajectory_bw[::-1]
def test__backward_actions_have_nonzero_forward_prob(env, n=1000):
# Skip for certain environments until fixed:
skip_envs = ["Crystal", "LatticeParameters"]
if env.__class__.__name__ in skip_envs:
warnings.warn("Skipping test for this specific environment.")
return
states = _get_terminating_states(env, n)
if states is None:
warnings.warn("Skipping test because states are None.")
return
policy_random = torch.unsqueeze(env.random_policy_output, 0)
for state in states:
env.set_state(state, done=True)
> while True:
E AssertionError
common.py:452: AssertionError
============================================ warnings summary =============================================
../../../../../miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10
/Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
from pkg_resources import resource_filename as rf
test_clattice_parameters.py::test__continuous_env_common[tetragonal]
/Users/jdv/code/gflownet/gflownet/envs/cube.py:1330: UserWarning:
State is out of cube bounds.
Current state:
[0.10000000826517741, 0.1820143616994222, 0.15523421696821849, 0.12497658960024516, 0.10097558728853862, 0.14860373667875926]
Action:
(0.10000000894069672, 0.13387973606586456, 0.12324856966733932, 0.11009382456541061, 0.10047899186611176, 0.1254289448261261, 0.0)
Next state: [-6.755193071583676e-10, 0.04813462563355764, 0.03198564730087916, 0.014882765034834544, 0.0004965954224268598, 0.02317479185263316]
warnings.warn(
test_composition.py: 956 warnings
/Users/jdv/code/gflownet/gflownet/utils/buffer.py:136: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
self.main = pd.concat(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================= short test summary info =========================================
FAILED test_ccrystal.py::test__continuous_env_common - AssertionError
FAILED test_clattice_parameters.py::test__continuous_env_common[tetragonal] - AssertionError
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! KeyboardInterrupt !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
/Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/torch/distributions/distribution.py:58: KeyboardInterrupt
(to show a full traceback on KeyboardInterrupt use --full-trace)
================= 2 failed, 1181 passed, 28 skipped, 958 warnings in 30739.55s (8:32:19) ==================
So I've tried a few configurations, and I think this is the cleanest solution that does what we need without getting too complex:

import pytest
import inspect


def get_current_method_name():
    """Helper method to get the name of the current method."""
    return inspect.currentframe().f_back.f_code.co_name


class Env:
    def __init__(self, value):
        self.value = value
        self.state = "000"


@pytest.fixture
def env():
    return Env(0)


class BaseTestClass:
    def test_common_1(self, n_repeat=1):
        if get_current_method_name() in self.repeats:
            n_repeat = self.repeats[get_current_method_name()]
        for _ in range(n_repeat):
            print("test 1")
            self.env.value += 1  # Note the env is persistent across repeats.
            print(self.env.value)

    def test_common_2(self, n_repeat=1):
        if get_current_method_name() in self.repeats:
            n_repeat = self.repeats[get_current_method_name()]
        for _ in range(n_repeat):
            print("test 2")
            self.env.state += "+"  # Note the env is persistent across repeats.
            print(self.env.state)


class TestSpecificInstance1(BaseTestClass):
    @pytest.fixture(autouse=True)
    def setup(self, env):
        self.env = env  # This sets up the env ONCE at the beginning. State
        # management must happen before each test is run, if required.
        self.repeats = {
            "test_common_2": 3,  # Sets repeat for self.test_common_2.
        }


class TestSpecificInstance2(BaseTestClass):
    @pytest.fixture(autouse=True)
    def setup(self, env):
        self.env = env  # This sets up the env ONCE at the beginning.
        self.repeats = {
            "test_common_1": 3,  # Sets repeat for self.test_common_1.
        }
Essentially, to control the granularity of the tests, you just need to define a dict for each environment-specific test class. None of the tests themselves need to be overridden. This also allows for states to be handled on a per-repeat basis, per test, which might be required for good tests.
It gives the desired results below (note the env state / value is modified correctly).
(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet:
pytest test_test.py -s
============================== test session starts ==============================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/jdv/code/gflownet/tests/gflownet
plugins: torchtyping-0.1.4, repeat-0.9.3, anyio-4.1.0, typeguard-4.1.5, hydra-core-1.3.2
collected 4 items
test_test.py
test 1
1
test 2
000+
test 2
000++
test 2
000+++
test 1
1
test 1
2
test 1
3
test 2
000+
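To illustrate the per-repeat state handling mentioned above, the env can be reset at the top of the repeat loop; a sketch of that variant, using the same toy Env plus a hypothetical reset() method:

```python
import inspect


def get_current_method_name():
    """Helper method to get the name of the current method."""
    return inspect.currentframe().f_back.f_code.co_name


class Env:
    """Toy env mirroring the example above, with a hypothetical reset()."""

    def __init__(self, value):
        self.value = value
        self.state = "000"

    def reset(self):
        self.state = "000"


class BaseTestClass:
    def test_common_2(self, n_repeat=1):
        n_repeat = self.repeats.get(get_current_method_name(), n_repeat)
        for _ in range(n_repeat):
            self.env.reset()  # per-repeat state management
            self.env.state += "+"
```

With the reset inside the loop, each repeat starts from "000" instead of accumulating "+"s across repeats.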
Let me know what you think of this general design and I'll implement it for a single env and push it as a draft PR before overhauling all of them.
@alexhernandezgarcia @carriepl
For the record, this seems pretty neat and I'd be happy to move ahead with the plan! Thanks!
I've pushed a proposed solution here -- to be reviewed before I do this for all the environments.
This is handled in the aforementioned PR.
The way the common tests are implemented has several issues. Importantly, the repetition decorators are ignored.
See this comment: https://github.com/alexhernandezgarcia/gflownet/pull/204/files#r1339115115