python3Packages.jax: support MKL BLAS/LAPACK implementations

This commit is contained in:
Samuel Ainsworth 2022-02-25 21:51:29 +00:00 committed by Frederik Rietdijk
parent b398f196e6
commit 33984cd89c

View file

@ -1,8 +1,10 @@
{ lib
, absl-py
, blas
, buildPythonPackage
, fetchFromGitHub
, jaxlib
, lapack
, numpy
, opt-einsum
, pytestCheckHook
@ -12,6 +14,9 @@
, typing-extensions
}:
let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
pname = "jax";
version = "0.3.1";
@ -59,6 +64,9 @@ buildPythonPackage rec {
"tests/"
];
# See https://github.com/google/jax/issues/9705.
disabledTests = lib.optionals usingMKL [ "test_custom_root_with_aux" ];
pythonImportsCheck = [
"jax"
];