comp-imaging / ProxImaL

A domain-specific language for image optimization.
MIT License

ladmm variable `v` shape must match that of `omega_fns` #62

Closed: antonysigma closed this 2 years ago

antonysigma commented 2 years ago

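This new test case exposes a bug in the `linearized_admm` algorithm: the variable `v` passed to the proximal operators of the `omega_fns` is a flattened vector, so its shape does not match the input shape each omega function expects (here `(9,)` versus `(3, 3)`).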
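The mismatch can be reproduced with NumPy alone. This is a minimal sketch, assuming (from the traceback) that `norm1` keeps a scratch buffer `v_sign` shaped like its variable and calls `np.sign(v, self.v_sign)`. The buffer and variable names below are illustrative, not ProxImaL internals.

```python
import numpy as np

# Scratch buffer shaped like the omega function's variable, (3, 3).
# This mirrors `self.v_sign` in the traceback; the shape is inferred from the error.
v_sign = np.empty((3, 3))

# The solver hands over a flattened iterate of shape (9,).
v = np.zeros(9)

try:
    np.sign(v, v_sign)  # same call as norm1's NumPy path; (9,) cannot broadcast to (3, 3)
except ValueError as err:
    print(type(err).__name__, err)

# Restoring the variable's shape first lets the ufunc write into the buffer.
np.sign(v.reshape(v_sign.shape), v_sign)
print(v_sign.shape)  # (3, 3)
```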

This is how to trigger the exception with pytest:

$ python3 -m pytest proximal/tests/test_algs.py

============================================= test session starts ==============================================
platform linux -- Python 3.6.9, pytest-5.4.3, py-1.9.0, pluggy-0.13.1
rootdir: /home/antony/Documents/ProxImaL
plugins: odl-0.7.0, cov-2.10.1
collected 11 items                                                                                             

tests/test_algs.py ...s.....F.                                                                           [100%]

=================================================== FAILURES ===================================================
______________________________ TestAlgs.test_lin_admm_two_prox_fn_shape_matching _______________________________

self = <proximal.tests.test_algs.TestAlgs object at 0x7f7faf22dcc0>

    def test_lin_admm_two_prox_fn_shape_matching(self):
        # With nested linear operators.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])

        N = 3
        factor = 1
        b = np.ones((N, N))
        x = px.Variable((N * factor, N * factor))
        prox_fns = [
            px.norm1(x),
            px.sum_squares(px.subsample(px.conv(kernel_mat, x),
                                        (factor, factor)),
                           b=b),
        ]
        psi_fns, omega_fns = ladmm.partition(prox_fns)
        sltn = ladmm.solve(psi_fns,
                           omega_fns,
                           0.1,
                           max_iters=3000,
                           eps_abs=1e-5,
>                          eps_rel=1e-5)

tests/test_algs.py:426: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
algorithms/linearized_admm.py:101: in solve
    lin_solver=lin_solver, options=lin_solver_options)
prox_fns/prox_fn.py:98: in prox
    xhat = self._prox(rho_hat, v, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <proximal.prox_fns.norm1.norm1 object at 0x7f7faf261c18>, rho = array(3240.0)
v = array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]), args = ()
kwargs = {'lin_solver': 'cg', 'options': None, 'x_init': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])}

    def _prox(self, rho, v, *args, **kwargs):
        """x = sign(v)*(|v| - 1/rho)_+
        """

        if self.implementation == Impl['halide'] and (len(self.lin_op.shape) in [2, 3, 4]):
            # Halide implementation
            Halide('prox_L1').prox_L1(v, 1. / rho, self.tmpout)
            np.copyto(v, self.tmpout)

        else:
            # Numpy implementation
>           np.sign(v, self.v_sign)
E           ValueError: operands could not be broadcast together with shapes (9,) (3,3)

prox_fns/norm1.py:28: ValueError
=========================================== short test summary info ============================================
FAILED tests/test_algs.py::TestAlgs::test_lin_admm_two_prox_fn_shape_matching - ValueError: operands could no...
=================================== 1 failed, 9 passed, 1 skipped in 16.25s ====================================
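One way to avoid the error is to reshape `v` to the omega function's variable shape before calling its prox, then flatten the result back for the solver. The sketch below assumes the solver keeps its iterates as flat vectors. `Norm1Prox` and `prox_on_flat` are hypothetical names, not ProxImaL's API, and the actual patch in this PR may differ.

```python
import numpy as np


class Norm1Prox:
    """Toy stand-in for an omega function whose prox writes into a buffer
    preallocated with the variable's shape (as norm1's `v_sign` appears to)."""

    def __init__(self, shape):
        self.shape = shape
        self.v_sign = np.empty(shape)

    def prox(self, rho, v):
        # Soft-thresholding: sign(v) * max(|v| - 1/rho, 0).
        np.sign(v, self.v_sign)  # requires v to broadcast to self.shape
        return self.v_sign * np.maximum(np.abs(v) - 1.0 / rho, 0.0)


def prox_on_flat(fn, rho, v_flat):
    """Hypothetical adapter: reshape the flat iterate to fn's variable shape,
    apply the prox, and return the result flattened again."""
    x = fn.prox(rho, v_flat.reshape(fn.shape))
    return x.ravel()


fn = Norm1Prox((3, 3))
v_flat = np.arange(-4.0, 5.0)  # flattened iterate, shape (9,)
x_flat = prox_on_flat(fn, 2.0, v_flat)
print(x_flat)
```

Keeping the reshape in one adapter means the per-function prox code can continue to assume its natural shape while the solver works on flat vectors.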

Resolves #63.

SteveDiamond commented 2 years ago

It looks OK. If you get all the tests passing, then it's fine.

SteveDiamond commented 2 years ago

I see there are some issues with the CI. You need to fix those before we can merge anything.

antonysigma commented 2 years ago

Thanks @SteveDiamond for reviewing the changes. The CI failures come from the master branch, not from my changes, though. Sure, I will look into it and submit a separate merge request to fix the CI.

Update: this could be related to the CI failure: https://stackoverflow.com/a/69100830