diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 12a69358b45..4901467262f 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -89,7 +89,7 @@ buildPythonPackage rec { "test_custom_linear_solve_cholesky" "test_custom_root_with_aux" "testEigvalsGrad_shape" - ] ++ lib.optionals (stdenv.isAarch64 && stdenv.isDarwin) [ + ] ++ lib.optionals stdenv.isAarch64 [ # See https://github.com/google/jax/issues/14793. "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" "testQdwhWithRandomMatrix3" @@ -107,6 +107,9 @@ buildPythonPackage rec { "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" + ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + # RuntimeWarning: invalid value encountered in cast + "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed.