Commit graph

22 commits

Author SHA1 Message Date
Samuel Ainsworth 56d7a75ad7
Merge pull request #168452 from samuela/upkeep-bot/python3Packages.jax-0.3.6-1649833784
python3Packages.jax: 0.3.5 -> 0.3.6
2022-04-17 12:33:59 -07:00
Samuel Ainsworth 06d2cfa7b4
Merge pull request #168707 from samuela/samuela/jax
python3Packages.jax: fix build when testing against jaxlibWithCuda
2022-04-17 12:31:12 -07:00
nixpkgs-upkeep-bot 1cd3aa8ff5 python3Packages.jax: 0.3.5 -> 0.3.6 2022-04-16 15:15:50 +00:00
Samuel Ainsworth cb7b514703 python3Packages.jax: fix build when nixpkgs-wide cudaSupport is enabled 2022-04-16 02:23:12 +00:00
Martin Weinelt 33425fdc96
Merge pull request #166489 from NixOS/python-updates 2022-04-15 03:47:30 +02:00
Martin Weinelt 836e3af544 python3Packages.jax: disable test_custom_linear_solve_aux
```
______________ CustomLinearSolveTest.test_custom_linear_solve_aux ______________
[gw3] linux -- Python 3.9.11 /nix/store/k1physzalj5vffsvl7ag6h6b6vaqip5x-python3-3.9.11/bin/python3.9

self = <custom_linear_solve_test.CustomLinearSolveTest testMethod=test_custom_linear_solve_aux>

    @jtu.skip_on_flag("jax_skip_slow_tests", True)
    def test_custom_linear_solve_aux(self):
      def explicit_jacobian_solve_aux(matvec, b):
        x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
        return x, array_aux

      def matrix_free_solve_aux(matvec, b):
        return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
          symmetric=True, has_aux=True)

      def linear_solve_aux(a, b):
        return matrix_free_solve_aux(partial(high_precision_dot, a), b)

      # array aux values, to be able to use jtu.check_grads
      array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
      rng = self.rng()
      a = rng.randn(3, 3)
      a = a + a.T
      b = rng.randn(3)

      expected = jnp.linalg.solve(a, b)
      actual_nojit, nojit_aux = linear_solve_aux(a, b)
      actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b)

      self.assertAllClose(expected, actual_nojit)
      self.assertAllClose(expected, actual_jit)
      # scalar dict equality check
      self.assertDictEqual(nojit_aux, array_aux)
      self.assertDictEqual(jit_aux, array_aux)

      # jvp / vjp test
>     jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)

tests/custom_linear_solve_test.py:157:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax/_src/test_util.py:372: in check_grads
    _check_grads(f, args, order)
jax/_src/test_util.py:361: in _check_grads
    _check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg)
jax/_src/test_util.py:365: in _check_grads
    _check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg)
jax/_src/test_util.py:325: in check_vjp
    check_close(ip, ip_expected, atol=atol, rtol=rtol,
jax/_src/test_util.py:227: in check_close
    tree_all(tree_multimap(assert_close, xs, ys))
jax/_src/tree_util.py:180: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
jax/_src/tree_util.py:180: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
jax/_src/test_util.py:217: in _assert_numpy_close
    _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

a = array(1.89683694), b = array(1.88698006), atol = 0.002, rtol = 0.004
err_msg = 'VJP of JVP cotangent projection'

    def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
      if a.dtype == b.dtype == _dtypes.float0:
        np.testing.assert_array_equal(a, b, err_msg=err_msg)
        return
      a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
      b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
      kw = {}
      if atol: kw["atol"] = atol
      if rtol: kw["rtol"] = rtol
      with np.errstate(invalid='ignore'):
        # TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
        # value errors. It should not do that.
>       np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
E       AssertionError:
E       Not equal to tolerance rtol=0.004, atol=0.002
E       VJP of JVP cotangent projection
E       Mismatched elements: 1 / 1 (100%)
E       Max absolute difference: 0.00985688
E       Max relative difference: 0.00522363
E        x: array(1.896837)
E        y: array(1.88698)

jax/_src/test_util.py:187: AssertionError
```
2022-04-15 01:39:54 +02:00
Martin Weinelt d57404ea3d python3Packages.jax: test with limited parallelism
The tests are prone to getting stuck with high parallelism.
2022-04-15 01:39:53 +02:00
nixpkgs-upkeep-bot bb2018c026 python3Packages.jax: 0.3.4 -> 0.3.5 2022-04-08 00:34:47 +00:00
R. Ryantm 011f9bfd39 python310Packages.jax: 0.3.3 -> 0.3.4 2022-03-19 03:15:53 +00:00
nixpkgs-upkeep-bot a05a02a88a python3Packages.jax: 0.3.1 -> 0.3.3 2022-03-18 00:37:10 +00:00
Samuel Ainsworth 069b742520 python3Packages.jax: fix MKL-enabled build on Intel CPUs
Intel doesn't know how to build a CPU it turns out. Don't use them for floating point arithmetic, even with their own MKL!
2022-02-27 08:06:53 +00:00
Samuel Ainsworth 33984cd89c python3Packages.jax: support MKL BLAS/LAPACK implementations 2022-02-26 07:33:28 +01:00
nixpkgs-upkeep-bot cef32aee43 python3Packages.jax: 0.3.0 -> 0.3.1 2022-02-19 00:42:06 +00:00
Samuel Ainsworth e9015499e0 python3Packages.jax: 0.2.28 -> 0.3.0 2022-02-14 21:40:19 +00:00
nixpkgs-upkeep-bot ee4d3df08b python3Packages.jax: 0.2.27 -> 0.2.28 2022-02-02 10:05:04 -08:00
nixpkgs-upkeep-bot 53047281eb python3Packages.jax: 0.2.26 -> 0.2.27 2022-01-19 00:47:57 +00:00
Alexander Tsvyashchenko be52722509
python3Packages.jaxlib: refactor to support Nix-based builds (#151909)
* python3Packages.jaxlib: rename to `jaxlib-bin`

Refactoring `jaxlib` to have a similar structure to `tensorflow` with the 'bin' and 'build' options.

* python3Packages.jaxlib: init the 'build' variant at 0.1.75

Similar to `tensorflow-build`, now there's an option to build `jaxlib` using Nix-provided environment and dependencies.

* python3Packages.jax: 0.2.24 -> 0.2.26

* Addressed review comments.

* Fixed `cudaSupport` missing property on some arches.

* Unified the versions of CUDA-related packages with TF.

Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
2021-12-27 16:19:10 -08:00
Jonathan Ringer 685788d164
python3Packages.jax: 0.2.24 -> 0.2.25 2021-11-30 18:51:59 -08:00
Fabian Affolter ddf89d7d96 python3Packages.jax: 0.2.21 -> 0.2.24 2021-11-09 08:43:22 -08:00
Martin Weinelt 1b07699d46 python3Packages.jax: 0.2.19 -> 0.2.21 2021-10-11 01:22:03 +02:00
Samuel Ainsworth 6f44416cf2 python3Packages.jax: remove meta.description period 2021-09-01 21:04:02 +00:00
Samuel Ainsworth 426569a041 python3Packages.jax: init at 0.2.19 2021-08-22 20:39:04 +00:00