osqp / osqp-python

Python interface for OSQP
https://osqp.org/
Apache License 2.0
109 stars 41 forks source link

Update unit test test_dl_dq_nonzero_dy #156

Open · vineetbansal opened this issue 2 months ago

vineetbansal commented 2 months ago

Since the API for `osqp_adjoint_derivative_compute` has changed (it now takes a single `y` instead of `y_low` and `y_high`), the wrapper definitions and the tests have been updated accordingly. osqp-python is therefore pinned to osqp commit `ff3716` (current as of this writing).

One test that I could not fix after the wrapper change was `test_dl_dq_nonzero_dy` (original implementation below), so I've disabled it by adding a leading underscore to its name. Someone needs to re-enable it after updating it to check whatever it was originally meant to test.

 def test_dl_dq_nonzero_dy(self, verbose=False):
        """Compare the analytic dq against finite differences with equality rows.

        The first ``num_eq`` constraints are turned into equalities in place
        (``u[:num_eq] = l[:num_eq]``). The ``dq`` entry returned by
        ``self.get_grads(..., mode='qdldl')`` is then compared with a
        finite-difference gradient (``approx_fprime``) of ``loss``.
        ``loss`` re-solves the QP with OSQP's builtin algebra at the perturbed
        ``q`` and returns half the squared error of ``x`` and the split dual
        parts against ``true_x``/``true_yl``/``true_yu``.

        Args:
            verbose: when true, print both gradients rounded to 4 decimals.

        NOTE(review): this version predates the change of
        ``osqp_adjoint_derivative_compute`` to take a single ``y`` instead of
        ``y_low``/``y_high``. The ``true_yl``/``true_yu`` split below
        presumably has to be adapted to the current ``get_grads`` signature
        before the test can be re-enabled. TODO confirm.
        """
        n, m = 6, 3
        num_eq = 2

        P, q, A, l, u, true_x, true_yl, true_yu = self.get_prob(
            n=n, m=m, P_scale=1.0, A_scale=1.0
        )
        # Collapse the leading rows to equality constraints (modifies u in place).
        u[:num_eq] = l[:num_eq]

        def analytic_dq(q_vec, mode):
            # Only the second of the five returned gradients (dq) is needed.
            (_, dq_out, _, _, _) = self.get_grads(
                P, q_vec, A, l, u, true_x, true_yl, true_yu, mode=mode
            )
            return dq_out

        def loss(q_vec):
            solver = osqp.OSQP(algebra='builtin')
            solver.setup(
                P, q_vec, A, l, u,
                eps_abs=eps_abs,
                eps_rel=eps_rel,
                max_iter=max_iter,
                verbose=False,
            )
            result = solver.solve()
            if result.info.status != 'solved':
                raise ValueError('Problem not solved!')

            x_hat = result.x
            y_hat = result.y
            # Split the combined dual into its positive part (upper) and
            # the negated negative part (lower).
            yu_hat = np.maximum(y_hat, 0)
            yl_hat = -np.minimum(y_hat, 0)

            residuals = (
                x_hat - true_x,
                yl_hat - true_yl,
                yu_hat - true_yu,
            )
            return 0.5 * sum(np.sum(np.square(r)) for r in residuals)

        dq_analytic = analytic_dq(q, 'qdldl')
        dq_numeric = approx_fprime(q, loss, grad_precision)

        if verbose:
            print('dq_fd: ', np.round(dq_numeric, decimals=4))
            print('dq_qdldl: ', np.round(dq_analytic, decimals=4))

        npt.assert_allclose(dq_numeric, dq_analytic, rtol=rel_tol, atol=abs_tol)