data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
MIT License

Some PyTorch fixes #140

Closed: asmeurer closed this 3 months ago

rgommers commented 3 months ago

Thanks @asmeurer. That all seems reasonable. For context: does it fix any issues, or CI or test suite failures visible somewhere?

asmeurer commented 3 months ago

Yes. They don't show up on CI yet because I still need to finish fixing the tests for PyTorch, but these were the failures:

================================================================================ FAILURES =================================================================================
_______________________________________________________________________________ test_solve ________________________________________________________________________________

>   @pytest.mark.xp_extension('linalg')

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x1 = tensor([[[1.]]]), x2 = tensor([[0.]])

    def test_solve(x1, x2):
        res = linalg.solve(x1, x2)

        ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
        if x2.ndim == 1:
            expected_shape = x1.shape[:-2] + x2.shape[-1:]
            _test_stacks(linalg.solve, x1, x2, res=res, dims=1,
                         matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
        else:
            stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
            expected_shape = stack_shape + x2.shape[-2:]
            _test_stacks(linalg.solve, x1, x2, res=res, dims=2)

>       ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
                               out_shape=res.shape, expected=expected_shape)
E       AssertionError: out.shape=torch.Size([1, 1]), but should be (1, 1, 1) [solve( torch.Size([1, 1, 1]) . torch.Size([1, 1]) )]
E       Falsifying example: test_solve(
E           x1=tensor([[[1.]]]),
E           x2=tensor([[0.]]),
E       )

array_api_tests/ AssertionError
____________________________________________________________________________ test_vector_norm _____________________________________________________________________________

>   @pytest.mark.xp_extension('linalg')

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = tensor([0.+0.j, 0.+0.j]), data = data(...)

        x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)),
    def test_vector_norm(x, data):
        kw = data.draw(
            # We use data because axes is parameterized on x.ndim
                       sampled_from([2, 1, 0, -1, -2, float("inf"), float("-inf")]),
                       integers(-max_ord, max_ord),
                       floats(-max_ord, max_ord),
                   )), label="kw")

        res = linalg.vector_norm(x, **kw)
        axis = kw.get('axis', None)
        keepdims = kw.get('keepdims', False)
        # TODO: Check that the ord values give the correct norms.
        # ord = kw.get('ord', 2)

        _axes = sh.normalise_axis(axis, x.ndim)

>       ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
                                    in_shape=x.shape, axes=_axes,
                                    keepdims=keepdims, kw=kw)
E       AssertionError: out.shape=torch.Size([1]), but should be (2,) [linalg.vector_norm(axis=())]
E       Falsifying example: test_vector_norm(
E           x=tensor([0.+0.j, 0.+0.j]),
E           data=data(...),
E       )
E       Draw 1 (kw): {'axis': ()}

array_api_tests/ AssertionError
========================================================================= short test summary info =========================================================================
FAILED array_api_tests/ - AssertionError: out.shape=torch.Size([1, 1]), but should be (1, 1, 1) [solve( torch.Size([1, 1, 1]) . torch.Size([1, 1]) )]
FAILED array_api_tests/ - AssertionError: out.shape=torch.Size([1]), but should be (2,) [linalg.vector_norm(axis=())]
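The `test_solve` failure depends on how the right-hand side is interpreted. Judging by the test code above, the Array API treats `x2` as a single vector only when `x2.ndim == 1`. Otherwise it is a stack of `(M, K)` matrices whose stack dimensions broadcast with those of `x1`. The sketch below models result shapes only, in plain Python, with illustrative function names. `torch_like_solve_shape` encodes an assumption that is consistent with the `torch.Size([1, 1])` in the log but has not been checked against PyTorch's source here: `torch.linalg.solve` also treats `x2` as a batch of vectors whenever `x2.shape == x1.shape[:-1]`. `compat_solve_shape` is a hypothetical workaround, not necessarily what this PR does. It prepends a size-1 dimension to `x2` in that ambiguous case.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Broadcast shape tuples right-to-left (NumPy / Array API rules)."""
    result = []
    # Pad shorter shapes on the left with 1s by walking from the right.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def spec_solve_shape(x1_shape, x2_shape):
    """Result shape of solve(x1, x2) as the Array API test expects it."""
    if len(x2_shape) == 1:
        # x2 is one vector: one solution vector per matrix in x1.
        return x1_shape[:-2] + x2_shape[-1:]
    # x2 is a (stack of) (M, K) matrices; the stack dims broadcast.
    return broadcast_shapes(x1_shape[:-2], x2_shape[:-2]) + x2_shape[-2:]


def torch_like_solve_shape(x1_shape, x2_shape):
    """ASSUMED model of torch.linalg.solve: x2 is a batch of vectors when it
    is 1-D *or* when x2.shape == x1.shape[:-1]."""
    vector_case = len(x2_shape) == 1 or (
        len(x2_shape) == len(x1_shape) - 1 and x2_shape == x1_shape[:-1]
    )
    if vector_case:
        return broadcast_shapes(x1_shape[:-2], x2_shape[:-1]) + x1_shape[-1:]
    return broadcast_shapes(x1_shape[:-2], x2_shape[:-2]) + x2_shape[-2:]


def compat_solve_shape(x1_shape, x2_shape):
    """Hypothetical workaround: prepend a size-1 dim to x2 in the ambiguous
    case so it is treated as a stack of matrices."""
    ambiguous = (
        len(x2_shape) != 1
        and len(x2_shape) == len(x1_shape) - 1
        and x2_shape == x1_shape[:-1]
    )
    if ambiguous:
        x2_shape = (1,) + x2_shape
    return torch_like_solve_shape(x1_shape, x2_shape)


# The shapes from the falsifying example in the log:
print(spec_solve_shape((1, 1, 1), (1, 1)))        # (1, 1, 1)
print(torch_like_solve_shape((1, 1, 1), (1, 1)))  # (1, 1)
print(compat_solve_shape((1, 1, 1), (1, 1)))      # (1, 1, 1)
```

Prepending the extra dimension does not change the broadcast stack shape. In the ambiguous case `x2` has exactly one fewer dimension than `x1`, so the new leading 1 lines up with `x1`'s first stack dimension and broadcasts against it.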

I'm not completely sure why these only showed up now. Maybe something changed in a recent PyTorch version, or else something was broken in the tests.
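The `test_vector_norm` failure concerns `axis=()`. In the test, `sh.normalise_axis` converts `axis` into a tuple of axes, and `ph.assert_keepdimable_shape` checks the output shape against that tuple. An empty tuple means no axes are reduced, so a shape-`(2,)` input should produce a shape-`(2,)` result, whereas `axis=None` reduces over every axis. Below is a minimal re-implementation of that shape rule for illustration. `normalise_axis` and `keepdimable_shape` are hypothetical stand-ins that mimic the helpers named in the traceback. They are not the test suite's actual code.

```python
def normalise_axis(axis, ndim):
    """Hypothetical stand-in: turn an Array API `axis` argument into a tuple
    of non-negative axis indices. None means every axis; () means none."""
    if axis is None:
        return tuple(range(ndim))
    if isinstance(axis, int):
        axis = (axis,)
    axes = []
    for a in axis:
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} is out of bounds for ndim={ndim}")
        axes.append(a + ndim if a < 0 else a)
    return tuple(axes)


def keepdimable_shape(in_shape, axes, keepdims):
    """Expected result shape after reducing `axes` of an array of `in_shape`."""
    if keepdims:
        # Reduced axes stay in place with length 1.
        return tuple(1 if i in axes else n for i, n in enumerate(in_shape))
    # Reduced axes are dropped; with axes == () nothing is dropped.
    return tuple(n for i, n in enumerate(in_shape) if i not in axes)


# The falsifying example: x has shape (2,) and kw == {'axis': ()}.
print(keepdimable_shape((2,), normalise_axis((), 1), keepdims=False))    # (2,)
# Contrast with axis=None, which reduces over all axes.
print(keepdimable_shape((2,), normalise_axis(None, 1), keepdims=False))  # ()
```

`axis=()` and `axis=None` sit at opposite ends: the first reduces nothing, and the second reduces everything. An implementation therefore has to tell them apart explicitly. In the log, PyTorch returns `torch.Size([1])` where the test expects `(2,)`.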