gehring / fax

MIT License

fix cg import #7

Closed: niklasschmitz closed this 4 years ago

niklasschmitz commented 4 years ago

Fixes the cg import in cg_test.py.

With the import fixed, running the test produces a failure whose error is only just above the default tolerance:

python cg_test.py

Running tests under Python 3.7.4: /home/niku/Misc/anaconda3/bin/python
[ RUN      ] CGTest.testSolveSimpleCase
/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Falsifying example: testSolveSimpleCase(
    self=<__main__.CGTest testMethod=testSolveSimpleCase>,
    amat=array([[0.1, 0.1, 0.1],
           [0.1, 0.1, 0.1],
           [0.1, 0.1, 0.1]]),
    bvec=array([0.1, 0.1, 0.1]),
)
[  FAILED  ] CGTest.testSolveSimpleCase
[ RUN      ] CGTest.testTupleSolveSimpleCase
[       OK ] CGTest.testTupleSolveSimpleCase
======================================================================
FAIL: testSolveSimpleCase (__main__.CGTest)
testSolveSimpleCase (__main__.CGTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "cg_test.py", line 19, in testSolveSimpleCase
    @hypothesis.given(
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/hypothesis/core.py", line 1081, in wrapped_test
    raise the_error_hypothesis_found
  File "cg_test.py", line 34, in testSolveSimpleCase
    check_dtypes=True)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 685, in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 658, in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/jax/test_util.py", line 111, in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1515, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/home/niku/Misc/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 841, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-15, atol=1e-15

Mismatch: 33.3%
Max absolute difference: 1.08940634e-15
Max relative difference: 1.84109672e-14
 x: array([0.059172, 0.059172, 0.059172])
 y: array([0.059172, 0.059172, 0.059172])

----------------------------------------------------------------------
Ran 2 tests in 59.195s

FAILED (failures=1)
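For reference, numpy.testing.assert_allclose accepts an element when |actual - desired| <= atol + rtol * |desired|. The sketch below plugs in the numbers from the log to show how narrow the miss is. Assumptions: the reference value 0.059172 is the 6-digit value printed in the log (the true value has more digits), and rtol=1e-12 is an illustrative looser tolerance, not a value taken from the repository.

```python
import numpy as np


def within_tolerance(abs_diff, reference, rtol, atol):
    """Element-wise criterion of numpy.testing.assert_allclose:
    |actual - desired| <= atol + rtol * |desired|."""
    return abs_diff <= atol + rtol * abs(reference)


# Numbers from the failure log (reference value truncated to 6 digits there).
REFERENCE = 0.059172
MAX_ABS_DIFF = 1.08940634e-15

# Tolerances in the failing run: rtol = atol = 1e-15.
# Allowed error: 1e-15 + 1e-15 * 0.059172 ~= 1.059e-15, observed ~= 1.089e-15.
default_ok = within_tolerance(MAX_ABS_DIFF, REFERENCE, rtol=1e-15, atol=1e-15)

# Illustrative looser relative tolerance (hypothetical choice).
loose_ok = within_tolerance(MAX_ABS_DIFF, REFERENCE, rtol=1e-12, atol=1e-15)

# Same comparison on whole arrays: one of three entries differs,
# consistent with the "Mismatch: 33.3%" line in the log.
desired = np.full(3, REFERENCE)
actual = desired + np.array([MAX_ABS_DIFF, 0.0, 0.0])
np.testing.assert_allclose(actual, desired, rtol=1e-12, atol=1e-15)  # passes
try:
    np.testing.assert_allclose(actual, desired, rtol=1e-15, atol=1e-15)
    strict_raised = False
except AssertionError:
    strict_raised = True

print(default_ok, loose_ok, strict_raised)
```

Since the traceback shows the test's assertAllClose forwarding atol and rtol down to NumPy, passing a slightly looser tolerance at that call site is one way to absorb this kind of floating-point round-off; whether that is the right fix for cg_test.py is a separate question.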
niklasschmitz commented 4 years ago

I really like the matrix-free implementation in conjugate_gradient_solve; it's exactly what I was looking for to use with JAX. Thanks!
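"Matrix-free" here means the solver only needs a function that applies the operator to a vector, never the matrix itself. A minimal NumPy sketch of that idea, assuming a symmetric positive-definite operator; cg_matrix_free is a hypothetical name and does not reproduce the signature of fax's conjugate_gradient_solve.

```python
import numpy as np


def cg_matrix_free(matvec, b, x0=None, tol=1e-10, max_iters=None):
    """Conjugate gradient for A x = b, where A is symmetric positive definite
    and only available through matvec(v) == A @ v (hypothetical helper)."""
    b = np.asarray(b, dtype=float)
    x = np.zeros_like(b) if x0 is None else np.asarray(x0, dtype=float).copy()
    if max_iters is None:
        max_iters = 10 * b.size
    r = b - matvec(x)  # residual
    p = r.copy()  # first search direction
    rs_old = r @ r
    for _ in range(max_iters):
        if np.sqrt(rs_old) <= tol:  # converged (also guards against b == 0)
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)  # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next A-conjugate direction
        rs_old = rs_new
    return x


# Usage: A = M^T M + I is applied through matvec without ever forming A.
rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
matvec = lambda v: M.T @ (M @ v) + v
b = rng.standard_normal(5)
x = cg_matrix_free(matvec, b)

A = M.T @ M + np.eye(5)  # built only to check the answer
print(np.allclose(A @ x, b))
```

Because only matvec touches the operator, the same loop works for operators that are too large to store, or that are defined implicitly (for example, through automatic differentiation in JAX).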

pierrelux commented 4 years ago

Thank you!