Closed niklasschmitz closed 4 years ago
Fixes the `cg` import in `cg_test.py`.
Running the test then fails, with a mismatch only slightly above the default tolerance:
```
python cg_test.py
Running tests under Python 3.7.4: /home/niku/Misc/anaconda3/bin/python
[ RUN ] CGTest.testSolveSimpleCase
/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Falsifying example: testSolveSimpleCase(
    self=<__main__.CGTest testMethod=testSolveSimpleCase>,
    amat=array([[0.1, 0.1, 0.1], [0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]),
    bvec=array([0.1, 0.1, 0.1]),
)
[ FAILED ] CGTest.testSolveSimpleCase
[ RUN ] CGTest.testTupleSolveSimpleCase
[ OK ] CGTest.testTupleSolveSimpleCase
======================================================================
FAIL: testSolveSimpleCase (__main__.CGTest)
testSolveSimpleCase (__main__.CGTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "cg_test.py", line 19, in testSolveSimpleCase
    @hypothesis.given(
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/hypothesis/core.py", line 1081, in wrapped_test
    raise the_error_hypothesis_found
  File "cg_test.py", line 34, in testSolveSimpleCase
    check_dtypes=True)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 685, in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 658, in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 111, in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1515, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 841, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-15, atol=1e-15

Mismatch: 33.3%
Max absolute difference: 1.08940634e-15
Max relative difference: 1.84109672e-14
 x: array([0.059172, 0.059172, 0.059172])
 y: array([0.059172, 0.059172, 0.059172])

----------------------------------------------------------------------
Ran 2 tests in 59.195s

FAILED (failures=1)
```
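To make "very close to the default tolerance" concrete, here is a minimal sketch that replays the numbers from the assertion message against the comparison rule documented for `numpy.testing.assert_allclose`, `|actual - desired| <= atol + rtol * |desired|`. The helper `within_tolerance` and the looser `rtol=1e-13` are hypothetical, chosen only for illustration; neither comes from this PR or from the test file.

```python
# Values copied from the assertion message above (printed precision only).
expected = 0.059172
actual = expected + 1.08940634e-15  # reported max absolute difference


def within_tolerance(actual, expected, rtol, atol):
    """Hypothetical helper mirroring the documented assert_allclose rule:
    |actual - expected| <= atol + rtol * |expected|.
    """
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Tolerances reported in the failure: rtol=1e-15, atol=1e-15.
strict = within_tolerance(actual, expected, rtol=1e-15, atol=1e-15)

# An arbitrary looser relative tolerance, for comparison only.
relaxed = within_tolerance(actual, expected, rtol=1e-13, atol=1e-15)

print("passes with rtol=1e-15:", strict)
print("passes with rtol=1e-13:", relaxed)
```

With these inputs the allowed difference under the strict setting is about 1.06e-15, so the reported 1.09e-15 misses it by a few percent. Whether the right fix is a looser tolerance or a change to the solver is a separate question that this sketch does not answer.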
I really like the matrix-free implementation in `conjugate_gradient_solve`, which is exactly what I was looking for to use with JAX; thanks!
Thank you!