pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0
236 stars · 20 forks

test_sample_partial test failure #609

Status: Open — opened by Apteryks 10 months ago

Apteryks commented 10 months ago

Hello,

I've seen this in GNU Guix CI:


=================================== FAILURES ===================================
___________________________ test_sample_partial[()] ____________________________

int_inputs = OrderedDict()

    @pytest.mark.parametrize(
        "int_inputs",
        [
            OrderedDict(),
            OrderedDict([("i", Bint[2])]),
            OrderedDict([("i", Bint[2]), ("j", Bint[3])]),
        ],
        ids=id_from_inputs,
    )
    def test_sample_partial(int_inputs):
        int_inputs = OrderedDict(sorted(int_inputs.items()))
        real_inputs = OrderedDict(
            [("w", Reals[2]), ("x", Reals[4]), ("y", Reals[2, 3]), ("z", Real)]
        )
        inputs = int_inputs.copy()
        inputs.update(real_inputs)
        flat = ops.cat(
            [Variable(k, d).reshape((d.num_elements,)) for k, d in real_inputs.items()]
        )

        def compute_moments(samples):
            flat_samples = flat(**extract_samples(samples))
            assert set(flat_samples.inputs) == {"particle"} | set(int_inputs)
            mean = flat_samples.reduce(ops.mean)
            diff = flat_samples - mean
            cov = (diff[:, None] - diff[None, :]).reduce(ops.mean)
            return mean, cov

        sample_inputs = OrderedDict(particle=Bint[50000])
        rng_keys = [None] * 3
        if get_backend() == "jax":
            import jax.random

            rng_keys = jax.random.split(np.array([0, 0], dtype=np.uint32), 3)

        g = random_gaussian(inputs)
        all_vars = frozenset("wxyz")
        samples = g.sample(all_vars, sample_inputs, rng_keys[0])
        expected_mean, expected_cov = compute_moments(samples)
        subsets = "w x y z wx wy wz xy xz yz wxy wxz wyz xyz".split()
        for sampled_vars in map(frozenset, subsets):
            g2 = g.sample(sampled_vars, sample_inputs, rng_keys[1])
            samples = g2.sample(all_vars, sample_inputs, rng_keys[2])
            actual_mean, actual_cov = compute_moments(samples)
>           assert_close(actual_mean, expected_mean, atol=1e-1, rtol=1e-1)

test/test_gaussian.py:843: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/gnu/store/zmr2dvs06mazgfnrxh2b9f5lvxxs1ylz-python-funsor-0.4.5/lib/python3.10/site-packages/funsor/testing.py:136: in assert_close
    assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

actual = array([-1.00111884, -0.68092102,  0.69566841, -0.46268585, -1.07358358,
       -0.24324944, -1.22855264,  0.21857192, -0.04263261,  1.09376803,
        0.77507805, -0.01998961, -1.81863148])
expected = array([-0.99112969, -0.68268335,  0.69722302, -0.45905905, -1.08039513,
       -0.24329676, -1.23492357,  0.2196864 , -0.04553655,  1.09680053,
        0.77216647, -0.035942  , -1.81212992])
atol = 0.1, rtol = 0.1

    def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
        msg = ActualExpected(actual, expected)
        if is_array(actual):
            assert is_array(expected), msg
        elif isinstance(actual, Tensor) and is_array(actual.data):
            assert isinstance(expected, Tensor) and is_array(expected.data), msg
        elif (
            isinstance(actual, Contraction)
            and isinstance(actual.terms[0], Tensor)
            and is_array(actual.terms[0].data)
        ):
            assert isinstance(expected, Contraction) and is_array(
                expected.terms[0].data
            ), msg
        elif isinstance(actual, Contraction) and isinstance(actual.terms[0], Delta):
            assert isinstance(expected, Contraction) and isinstance(
                expected.terms[0], Delta
            ), msg
        elif isinstance(actual, Gaussian):
            assert isinstance(expected, Gaussian)
        else:
            assert type(actual) == type(expected), msg

        if isinstance(actual, Funsor):
            assert isinstance(expected, Funsor), msg
            assert actual.inputs == expected.inputs, (actual.inputs, expected.inputs)
            assert actual.output == expected.output, (actual.output, expected.output)

        if isinstance(actual, (Number, Tensor)):
            assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
        elif isinstance(actual, Delta):
            assert frozenset(n for n, p in actual.terms) == frozenset(
                n for n, p in expected.terms
            )
            actual = actual.align(tuple(n for n, p in expected.terms))
            for (actual_name, (actual_point, actual_log_density)), (
                expected_name,
                (expected_point, expected_log_density),
            ) in zip(actual.terms, expected.terms):
                assert actual_name == expected_name
                assert_close(actual_point, expected_point, atol=atol, rtol=rtol)
                assert_close(actual_log_density, expected_log_density, atol=atol, rtol=rtol)
        elif isinstance(actual, Gaussian):
            # Note white_vec and prec_sqrt are expected to agree only up to an
            # orthogonal factor, but precision and info_vec should agree exactly.
            assert_close(actual._info_vec, expected._info_vec, atol=atol, rtol=rtol)
            assert_close(actual._precision, expected._precision, atol=atol, rtol=rtol)
        elif isinstance(actual, Contraction):
            assert actual.red_op == expected.red_op
            assert actual.bin_op == expected.bin_op
            assert actual.reduced_vars == expected.reduced_vars
            assert len(actual.terms) == len(expected.terms)
            for ta, te in zip(actual.terms, expected.terms):
                assert_close(ta, te, atol, rtol)
        elif type(actual).__name__ == "Tensor":
            assert get_backend() == "torch"
            import torch

            assert actual.dtype == expected.dtype, msg
            assert actual.shape == expected.shape, msg
            if actual.dtype in (torch.long, torch.uint8, torch.bool):
                assert (actual == expected).all(), msg
            else:
                eq = actual == expected
                if eq.all():
                    return
                if eq.any():
                    actual = actual[~eq]
                    expected = expected[~eq]
                diff = (actual.detach() - expected.detach()).abs()
                if rtol is not None:
                    assert (diff / (atol + expected.detach().abs())).max() < rtol, msg
                elif atol is not None:
                    assert diff.max() < atol, msg
        elif is_array(actual):
            if get_backend() == "jax":
                import jax

                assert jax.numpy.result_type(actual.dtype) == jax.numpy.result_type(
                    expected.dtype
                ), msg
            else:
                assert actual.dtype == expected.dtype, msg

            assert actual.shape == expected.shape, msg
            if actual.dtype in (np.int32, np.int64, np.uint8, bool):
                assert (actual == expected).all(), msg
            else:
                actual, expected = np.asarray(actual), np.asarray(expected)
                eq = actual == expected
                if eq.all():
                    return
                if eq.any():
                    actual = actual[~eq]
                    expected = expected[~eq]
                diff = abs(actual - expected)
                if rtol is not None:
>                   assert (diff / (atol + abs(expected))).max() < rtol, msg
E                   AssertionError: Expected:
E                   [-0.99112969 -0.68268335  0.69722302 -0.45905905 -1.08039513 -0.24329676
E                    -1.23492357  0.2196864  -0.04553655  1.09680053  0.77216647 -0.035942
E                    -1.81212992]
E                   Actual:
E                   [-1.00111884 -0.68092102  0.69566841 -0.46268585 -1.07358358 -0.24324944
E                    -1.22855264  0.21857192 -0.04263261  1.09376803  0.77507805 -0.01998961
E                    -1.81863148]

/gnu/store/zmr2dvs06mazgfnrxh2b9f5lvxxs1ylz-python-funsor-0.4.5/lib/python3.10/site-packages/funsor/testing.py:204: AssertionError
=========================== short test summary info ============================
FAILED test/test_gaussian.py::test_sample_partial[()] - AssertionError: Expec...
= 1 failed, 7637 passed, 3859 skipped, 69 xfailed, 2 xpassed in 417.82s (0:06:57) =
error: in phase 'check': uncaught exception:
%exception #<&invoke-error program: "/gnu/store/m8li9l31vqfl7f3m4zmdqykc5madv2hr-python-pytest-7.1.3/bin/pytest" arguments: ("-vv") exit-status: 1 term-signal: #f stop-signal: #f> 
phase `check' failed after 419.9 seconds
command "/gnu/store/m8li9l31vqfl7f3m4zmdqykc5madv2hr-python-pytest-7.1.3/bin/pytest" "-vv" failed with status 1
builder for `/gnu/store/z2mzsj06by0ffly2b8magk1vn1ics0dc-python-funsor-0.4.5.drv' failed with exit code 1
@ build-failed /gnu/store/z2mzsj06by0ffly2b8magk1vn1ics0dc-python-funsor-0.4.5.drv - 1 builder for `/gnu/store/z2mzsj06by0ffly2b8magk1vn1ics0dc-python-funsor-0.4.5.drv' failed with exit code 1

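It seems the tolerances used for this test need to be relaxed.

Only one of the 13 mean entries is out of tolerance: entry 11, where the expected value is -0.035942 and the actual value is -0.01998961, a difference of about 0.016.

`assert_close` checks `diff / (atol + abs(expected)) < rtol`. With `atol = rtol = 0.1`, an entry close to zero must match to within roughly `atol * rtol = 0.01`. Here the ratio is about 0.117, just above the 0.1 limit.

That is plausibly ordinary Monte Carlo noise in a 50000-particle estimate. The RNG keys are only fixed when the backend is JAX; otherwise `rng_keys` is `[None] * 3`. So on other backends the test is unseeded and can fail intermittently rather than deterministically.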