alexhernandezgarcia / gflownet

Generative Flow Networks - GFlowNet
https://gflownet.readthedocs.io/en/latest/
Apache License 2.0
161 stars 10 forks

Refactor common environment tests #242

Closed: alexhernandezgarcia closed this 7 months ago

alexhernandezgarcia commented 11 months ago

The way the common tests are implemented has several issues. Most importantly, the repetition decorators are ignored.

See this comment: https://github.com/alexhernandezgarcia/gflownet/pull/204/files#r1339115115
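A minimal sketch of why the decorators can be ignored. It assumes the common tests carry pytest markers such as `@pytest.mark.repeat` (the `repeat` plugin appears in the plugin list of the sessions below) and are then called as plain functions from a single collected test. Markers are only metadata that pytest reads when it collects a test, so calling the decorated function directly runs its body exactly once. `check_something` and `run_all_common` are hypothetical stand-ins, not gflownet code.

```python
import pytest

calls = []


# Hypothetical common check; the marker asks pytest-repeat for 5 runs.
@pytest.mark.repeat(5)
def check_something(env):
    calls.append(env)


def run_all_common(env):
    # Called as an ordinary function: nothing is collected, so the marker is never read.
    check_something(env)


run_all_common("dummy-env")
print(len(calls))  # 1, not 5
# The marker is still attached, but only as metadata on the function object.
print([mark.name for mark in check_something.pytestmark])  # ['repeat']
```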

josephdviviano commented 9 months ago

OK, I'm starting to look at this, in the dumbest way possible:

N = 2

def test__all_env_common(env):
    for _ in range(N):
        test__init__state_is_source_no_parents(env)
        ...
        test__gflownet_minimal_runs(env)

Any value of N > 1 raises an AssertionError. I'm not sure whether that's just because the env state needs to be reset between iterations.

However, a bigger problem arises:

$ pytest -x test_crystal.py

...

../../../../../miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10
  /Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import resource_filename as rf

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:140: UserWarning: Skipping test because states are None.
    warnings.warn("Skipping test because states are None.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:156: UserWarning: Skipping test because states are None.
    warnings.warn("Skipping test because states are None.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:178: UserWarning: Skipping test because states are None.
    warnings.warn("Skipping test because states are None.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:436: UserWarning: Skipping test for this specific environment.
    warnings.warn("Skipping test for this specific environment.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:398: UserWarning: Skipping test for this specific environment.
    warnings.warn("Skipping test for this specific environment.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/tests/gflownet/envs/common.py:205: UserWarning: Skipping test because states are None.
    warnings.warn("Skipping test because states are None.")

test_crystal.py::test__all_env_common
  /Users/jdv/code/gflownet/gflownet/utils/buffer.py:136: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
    self.main = pd.concat(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================== short test summary info =======================================
FAILED test_crystal.py::test__all_env_common - AssertionError
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
============================== 1 failed, 28 passed, 8 warnings in 2.55s ==============================
(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet/envs:

The test output is very noisy, and when a test fails it's not clear why. Maybe we should brainstorm options in our next standup :)
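One possible way to make the output quieter and the failures easier to read, sketched as a suggestion rather than what `common.py` currently does: replace the `warnings.warn(...); return` pattern with `pytest.skip(...)`, which pytest reports as a skip (`s`) instead of a pass plus a UserWarning, and give the bare `assert`s a message so the report names the offending value. `check_logprobs_finite` is a hypothetical helper.

```python
import math

import pytest


def check_logprobs_finite(logprobs, states):
    """Hypothetical helper mirroring the skip/assert pattern in common.py."""
    if states is None:
        # Raises pytest.skip.Exception; pytest records the test as skipped
        # with this reason instead of passing it with a UserWarning.
        pytest.skip("states are None for this environment")
    for logprob, state in zip(logprobs, states):
        # The f-string message is printed after "AssertionError" in the report.
        assert math.isfinite(logprob), f"non-finite logprob {logprob} for state {state}"
```

Running pytest with `-rs` would then list the skip reasons in the short test summary instead of scattering them through the warnings section.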

josephdviviano commented 9 months ago

I think env.reset() fixes the AssertionError. So the question is now only about test legibility.
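The repeat loop from the first snippet, with the reset added, might look like the sketch below. `ToyEnv`, `check_state_is_source` and `run_common` are hypothetical stand-ins for a gflownet environment and a common test; the point is only that a check which assumes the env starts at its source state passes on every iteration if, and only if, the env is reset before each repeat.

```python
class ToyEnv:
    """Hypothetical environment: a counter whose source state is 0."""

    def __init__(self):
        self.source = 0
        self.state = self.source

    def step(self):
        self.state += 1

    def reset(self):
        self.state = self.source
        return self


def check_state_is_source(env):
    # Like a common test that assumes a freshly initialised env, then mutates it.
    assert env.state == env.source
    env.step()


def run_common(env, n_repeats, reset=True):
    for _ in range(n_repeats):
        if reset:
            env.reset()  # without this, the second iteration sees state == 1
        check_state_is_source(env)


run_common(ToyEnv(), n_repeats=3)  # passes
```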

alexhernandezgarcia commented 9 months ago

Yes, env.reset() seems necessary. Thanks for looking into this. Let's definitely discuss this tomorrow! I hope to have a moment to check it beforehand.

josephdviviano commented 9 months ago

I'm running the tests now. They occasionally fail, but they also take a really long time; I'm not sure this is a practical solution (if we really need 500 tests).

(Sent by email in reply to Alex's comment of Tue, Dec 5, 2023 at 10:35.)

josephdviviano commented 9 months ago

OK, I had to kill the run. I think the actual runtime is more like 4 hours (not the ~8.5 hours reported), since my computer was in sleep mode for part of the day, but it's clear we need to reduce the runtime by ~100x.
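To find where a ~100x reduction could come from, pytest's built-in `--durations=N` option lists the slowest tests, but inside one aggregate test such as `test__all_env_common` it only reports a single number. A small timing wrapper like the hypothetical `timed` decorator below (not part of gflownet) is one way to attribute runtime to the individual common checks; `slow_check` is a stand-in for a check whose cost scales with its sample count, like the `n=1000` default in the traceback below.

```python
import time
from collections import defaultdict
from functools import wraps

# Accumulated wall-clock seconds per wrapped function name.
timings = defaultdict(float)


def timed(fn):
    """Hypothetical decorator: add each call's duration to timings[fn.__name__]."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            timings[fn.__name__] += time.perf_counter() - start

    return wrapper


@timed
def slow_check(n=1000):
    # Stand-in for a common test whose cost grows with the number of samples n.
    return sum(i * i for i in range(n))


slow_check(n=100_000)
# Slowest checks first.
for name, seconds in sorted(timings.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {seconds:.3f}s")
```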

Interestingly, some of the tests with repeats eventually fail. I'm copying the output below so we can discuss:

(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet/envs:
pytest .
======================================== test session starts =========================================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/jdv/code/gflownet/tests/gflownet/envs
plugins: torchtyping-0.1.4, repeat-0.9.3, anyio-4.1.0, typeguard-4.1.5, hydra-core-1.3.2
collected 1611 items                                                                                 

test_ccrystal.py ............................................................................. [  4%]
......................sss..................................................................... [ 10%]
........................................................F.                                     [ 14%]
test_ccube.py ................................................................................ [ 19%]
......................sssssssssss..                                                                                    [ 21%]
test_clattice_parameters.py .................................................................. [ 25%]
.............................................................................................. [ 31%]
.............................................................................................. [ 37%]
.............................................................................................. [ 42%]
.............................................................................................. [ 48%]
.............................................................................................. [ 54%]
.............................................................................................. [ 60%]
........................................................................................ssssss [ 66%]
ssssssss.....F.                                                                                     [ 67%]
test_composition.py .......................................................................... [ 71%]
......................................................^C

================================================ FAILURES =================================================
_______________________________________ test__continuous_env_common _______________________________________

env_sg_first = <gflownet.envs.crystals.ccrystal.CCrystal object at 0x178e9af20>

    def test__continuous_env_common(env_sg_first):
        print("\n\nCommon tests for crystal with space group first\n")
>       return common.test__continuous_env_common(env_sg_first)

test_ccrystal.py:1606: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
common.py:48: in test__continuous_env_common
    test__backward_actions_have_nonzero_forward_prob(env)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

env = <gflownet.envs.crystals.ccrystal.CCrystal object at 0x178e9af20>, n = 1000

    def test__backward_actions_have_nonzero_forward_prob(env, n=1000):
        # Skip for certain environments until fixed:
        skip_envs = ["Crystal", "LatticeParameters"]
        if env.__class__.__name__ in skip_envs:
            warnings.warn("Skipping test for this specific environment.")
            return
        states = _get_terminating_states(env, n)
        if states is None:
            warnings.warn("Skipping test because states are None.")
            return
        policy_random = torch.unsqueeze(env.random_policy_output, 0)
        for state in states:
            env.set_state(state, done=True)
            while True:
                if env.equal(env.state, env.source):
                    break
                state_next, action, valid = env.step_random(backward=True)
                assert valid
                # Get forward logprobs
                mask_fw = env.get_mask_invalid_actions_forward()
                masks = torch.unsqueeze(tbool(mask_fw, device=env.device), 0)
                actions_torch = torch.unsqueeze(
                    tfloat(action, float_type=env.float, device=env.device), 0
                )
                policy_outputs = policy_random.clone().detach()
                logprobs_fw = env.get_logprobs(
                    policy_outputs=policy_outputs,
                    actions=actions_torch,
                    mask=masks,
                    states_from=[env.state],
                    is_backward=False,
                )
>               assert torch.isfinite(logprobs_fw)
E               AssertionError

common.py:467: AssertionError
------------------------------------------ Captured stdout call -------------------------------------------

Common tests for crystal with space group first

_________________________________ test__continuous_env_common[tetragonal] _________________________________

env = <gflownet.envs.crystals.clattice_parameters.CLatticeParameters object at 0x179049420>
lattice_system = 'tetragonal'

    @pytest.mark.parametrize(
        "lattice_system",
        [CUBIC, HEXAGONAL, MONOCLINIC, ORTHORHOMBIC, RHOMBOHEDRAL, TETRAGONAL, TRICLINIC],
    )
    def test__continuous_env_common(env, lattice_system):
>       return common.test__continuous_env_common(env)

test_clattice_parameters.py:304: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
common.py:48: in test__continuous_env_common
    test__backward_actions_have_nonzero_forward_prob(env)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

env = <gflownet.envs.crystals.clattice_parameters.CLatticeParameters object at 0x179049420>, n = 1000

        )
        assert actions_trajectory_fw == actions_trajectory_bw[::-1]

    def test__backward_actions_have_nonzero_forward_prob(env, n=1000):
        # Skip for certain environments until fixed:
        skip_envs = ["Crystal", "LatticeParameters"]
        if env.__class__.__name__ in skip_envs:
            warnings.warn("Skipping test for this specific environment.")
            return
        states = _get_terminating_states(env, n)
        if states is None:
            warnings.warn("Skipping test because states are None.")
            return
        policy_random = torch.unsqueeze(env.random_policy_output, 0)
        for state in states:
            env.set_state(state, done=True)
>           while True:
E           AssertionError

common.py:452: AssertionError
============================================ warnings summary =============================================
../../../../../miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10
  /Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/pyxtal/symmetry.py:10: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import resource_filename as rf

test_clattice_parameters.py::test__continuous_env_common[tetragonal]
  /Users/jdv/code/gflownet/gflownet/envs/cube.py:1330: UserWarning: 
                  State is out of cube bounds.

  Current state:
  [0.10000000826517741, 0.1820143616994222, 0.15523421696821849, 0.12497658960024516, 0.10097558728853862, 0.14860373667875926]
  Action:
  (0.10000000894069672, 0.13387973606586456, 0.12324856966733932, 0.11009382456541061, 0.10047899186611176, 0.1254289448261261, 0.0)
  Next state: [-6.755193071583676e-10, 0.04813462563355764, 0.03198564730087916, 0.014882765034834544, 0.0004965954224268598, 0.02317479185263316]

    warnings.warn(

test_composition.py: 956 warnings
  /Users/jdv/code/gflownet/gflownet/utils/buffer.py:136: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
    self.main = pd.concat(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================= short test summary info =========================================
FAILED test_ccrystal.py::test__continuous_env_common - AssertionError
FAILED test_clattice_parameters.py::test__continuous_env_common[tetragonal] - AssertionError
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! KeyboardInterrupt !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
/Users/jdv/miniconda3/envs/gflownet/lib/python3.10/site-packages/torch/distributions/distribution.py:58: KeyboardInterrupt
(to show a full traceback on KeyboardInterrupt use --full-trace)
================= 2 failed, 1181 passed, 28 skipped, 958 warnings in 30739.55s (8:32:19) ==================

alexhernandezgarcia commented 9 months ago

This seems useful: https://stackoverflow.com/questions/50016862/grouping-tests-in-pytest-classes-vs-plain-functions

josephdviviano commented 8 months ago

So I've tried a few configurations, and I think this is the cleanest solution that does what we need without getting too complex:

import inspect

import pytest

def get_current_method_name():
    """Helper method to get the name of the method that called this function."""
    return inspect.currentframe().f_back.f_code.co_name

class Env:
    def __init__(self, value):
        self.value = value
        self.state = "000"

@pytest.fixture
def env():
    return Env(0)

class BaseTestClass:
    # Not collected by pytest (name does not start with "Test"); the
    # environment-specific subclasses below inherit and run these tests.
    repeats = {}  # Default: every test runs once unless a subclass overrides it.

    def test_common_1(self, n_repeat=1):
        # Arguments with default values are not treated as pytest fixtures.
        n_repeat = self.repeats.get(get_current_method_name(), n_repeat)

        for _ in range(n_repeat):
            print("test 1")
            self.env.value += 1  # Note the env is persistent across repeats.
            print(self.env.value)

    def test_common_2(self, n_repeat=1):
        n_repeat = self.repeats.get(get_current_method_name(), n_repeat)

        for _ in range(n_repeat):
            print("test 2")
            self.env.state += "+"  # Note the env is persistent across repeats.
            print(self.env.state)

class TestSpecificInstance1(BaseTestClass):
    @pytest.fixture(autouse=True)
    def setup(self, env):
        # Runs before every test method, so each test gets a fresh env. State
        # management between repeats must happen inside the test, if required.
        self.env = env
        self.repeats = {
            "test_common_2": 3,  # Sets repeat for self.test_common_2.
        }

class TestSpecificInstance2(BaseTestClass):
    @pytest.fixture(autouse=True)
    def setup(self, env):
        self.env = env  # Fresh env before every test method, as above.
        self.repeats = {
            "test_common_1": 3,  # Sets repeat for self.test_common_1.
        }

Essentially, to control the granularity of the tests, you just need to define a repeats dict in each environment-specific test class. None of the tests themselves need to be overridden. This also allows state to be handled per repeat, per test, which good tests might require.

It gives the desired results below (note that the env state/value persists across the repeats of a test and starts fresh for each test).

(gflownet) jdv@delarge ~/code/gflownet/tests/gflownet:
pytest test_test.py -s
============================== test session starts ==============================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/jdv/code/gflownet/tests/gflownet
plugins: torchtyping-0.1.4, repeat-0.9.3, anyio-4.1.0, typeguard-4.1.5, hydra-core-1.3.2
collected 4 items                                                               

test_test.py 

test 1
1
test 2
000+
test 2
000++
test 2
000+++
test 1
1
test 1
2
test 1
3
test 2
000+
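A variant of the same design, sketched as an alternative rather than what the PR does: the repeat count can be looked up from the bound method's `__name__` instead of inspecting the caller's frame, `repeats` can default to an empty dict on the base class so subclasses only list overrides, and an optional per-repeat hook gives a place to reset state when a test needs it. `n_repeats`, `setup_repeat`, `TestAccumulate` and `TestResetEachRepeat` are hypothetical names.

```python
import pytest


class Env:
    def __init__(self, value):
        self.value = value


class BaseTestClass:
    # Name does not start with "Test", so pytest collects only the subclasses.
    repeats: dict = {}  # test method name -> number of repeats

    def n_repeats(self, test_method, default=1):
        # Bound methods expose the underlying function's __name__.
        return self.repeats.get(test_method.__name__, default)

    def setup_repeat(self):
        """Hypothetical hook run before every repeat; no-op by default."""

    def test_common_1(self):
        for _ in range(self.n_repeats(self.test_common_1)):
            self.setup_repeat()
            self.env.value += 1


class TestAccumulate(BaseTestClass):
    repeats = {"test_common_1": 3}  # state persists: value ends at 3

    @pytest.fixture(autouse=True)
    def setup(self):
        self.env = Env(0)


class TestResetEachRepeat(TestAccumulate):
    def setup_repeat(self):
        self.env.value = 0  # reset state before each repeat: value ends at 1
```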

Let me know what you think of this general design. I'll implement it for a single env and push it as a draft PR before overhauling all of them.

@alexhernandezgarcia @carriepl

alexhernandezgarcia commented 8 months ago

For the record, this seems pretty neat and I'd be happy to move forward with the plan! Thanks!

josephdviviano commented 8 months ago

I've pushed a proposed solution here -- to be reviewed before I do this for all the environments.

https://github.com/alexhernandezgarcia/gflownet/pull/267/commits/5f4345d9a347c75d5cdbc7f609299c8ff9eb5372

josephdviviano commented 7 months ago

This is handled in the aforementioned PR.