Martin Weinelt
836e3af544
python3Packages.jax: disable test_custom_linear_solve_aux
...
```
______________ CustomLinearSolveTest.test_custom_linear_solve_aux ______________
[gw3] linux -- Python 3.9.11 /nix/store/k1physzalj5vffsvl7ag6h6b6vaqip5x-python3-3.9.11/bin/python3.9
self = <custom_linear_solve_test.CustomLinearSolveTest testMethod=test_custom_linear_solve_aux>
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_aux(self):
def explicit_jacobian_solve_aux(matvec, b):
x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
return x, array_aux
def matrix_free_solve_aux(matvec, b):
return lax.custom_linear_solve(
matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
symmetric=True, has_aux=True)
def linear_solve_aux(a, b):
return matrix_free_solve_aux(partial(high_precision_dot, a), b)
# array aux values, to be able to use jtu.check_grads
array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
rng = self.rng()
a = rng.randn(3, 3)
a = a + a.T
b = rng.randn(3)
expected = jnp.linalg.solve(a, b)
actual_nojit, nojit_aux = linear_solve_aux(a, b)
actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b)
self.assertAllClose(expected, actual_nojit)
self.assertAllClose(expected, actual_jit)
# scalar dict equality check
self.assertDictEqual(nojit_aux, array_aux)
self.assertDictEqual(jit_aux, array_aux)
# jvp / vjp test
> jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)
tests/custom_linear_solve_test.py:157:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax/_src/test_util.py:372: in check_grads
_check_grads(f, args, order)
jax/_src/test_util.py:361: in _check_grads
_check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg)
jax/_src/test_util.py:365: in _check_grads
_check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg)
jax/_src/test_util.py:325: in check_vjp
check_close(ip, ip_expected, atol=atol, rtol=rtol,
jax/_src/test_util.py:227: in check_close
tree_all(tree_multimap(assert_close, xs, ys))
jax/_src/tree_util.py:180: in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
jax/_src/tree_util.py:180: in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
jax/_src/test_util.py:217: in _assert_numpy_close
_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = array(1.89683694), b = array(1.88698006), atol = 0.002, rtol = 0.004
err_msg = 'VJP of JVP cotangent projection'
def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
if a.dtype == b.dtype == _dtypes.float0:
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
kw = {}
if atol: kw["atol"] = atol
if rtol: kw["rtol"] = rtol
with np.errstate(invalid='ignore'):
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
# value errors. It should not do that.
> np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
E AssertionError:
E Not equal to tolerance rtol=0.004, atol=0.002
E VJP of JVP cotangent projection
E Mismatched elements: 1 / 1 (100%)
E Max absolute difference: 0.00985688
E Max relative difference: 0.00522363
E x: array(1.896837)
E y: array(1.88698)
jax/_src/test_util.py:187: AssertionError
```
2022-04-15 01:39:54 +02:00