Revert "Update JAX"
This commit is contained in:
parent
cf6c391838
commit
8423edb179
|
@ -10,12 +10,9 @@ args@{
|
||||||
, bazelFlags ? []
|
, bazelFlags ? []
|
||||||
, bazelBuildFlags ? []
|
, bazelBuildFlags ? []
|
||||||
, bazelTestFlags ? []
|
, bazelTestFlags ? []
|
||||||
, bazelRunFlags ? []
|
|
||||||
, runTargetFlags ? []
|
|
||||||
, bazelFetchFlags ? []
|
, bazelFetchFlags ? []
|
||||||
, bazelTargets ? []
|
, bazelTargets
|
||||||
, bazelTestTargets ? []
|
, bazelTestTargets ? []
|
||||||
, bazelRunTarget ? null
|
|
||||||
, buildAttrs
|
, buildAttrs
|
||||||
, fetchAttrs
|
, fetchAttrs
|
||||||
|
|
||||||
|
@ -49,23 +46,17 @@ args@{
|
||||||
|
|
||||||
let
|
let
|
||||||
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
|
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
|
||||||
inherit
|
name = name;
|
||||||
name
|
bazelFlags = bazelFlags;
|
||||||
bazelFlags
|
bazelBuildFlags = bazelBuildFlags;
|
||||||
bazelBuildFlags
|
bazelTestFlags = bazelTestFlags;
|
||||||
bazelTestFlags
|
bazelFetchFlags = bazelFetchFlags;
|
||||||
bazelRunFlags
|
bazelTestTargets = bazelTestTargets;
|
||||||
runTargetFlags
|
dontAddBazelOpts = dontAddBazelOpts;
|
||||||
bazelFetchFlags
|
|
||||||
bazelTargets
|
|
||||||
bazelTestTargets
|
|
||||||
bazelRunTarget
|
|
||||||
dontAddBazelOpts
|
|
||||||
;
|
|
||||||
};
|
};
|
||||||
fBuildAttrs = fArgs // buildAttrs;
|
fBuildAttrs = fArgs // buildAttrs;
|
||||||
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
|
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
|
||||||
bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
|
bazelCmd = { cmd, additionalFlags, targets }:
|
||||||
lib.optionalString (targets != [ ]) ''
|
lib.optionalString (targets != [ ]) ''
|
||||||
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
|
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
|
||||||
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
|
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
|
||||||
|
@ -82,8 +73,7 @@ let
|
||||||
"''${host_linkopts[@]}" \
|
"''${host_linkopts[@]}" \
|
||||||
$bazelFlags \
|
$bazelFlags \
|
||||||
${lib.strings.concatStringsSep " " additionalFlags} \
|
${lib.strings.concatStringsSep " " additionalFlags} \
|
||||||
${lib.strings.concatStringsSep " " targets} \
|
${lib.strings.concatStringsSep " " targets}
|
||||||
${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
|
|
||||||
'';
|
'';
|
||||||
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
|
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
|
||||||
# chmod: cannot operate on dangling symlink '$symlink'
|
# chmod: cannot operate on dangling symlink '$symlink'
|
||||||
|
@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // {
|
||||||
targets = fBuildAttrs.bazelTargets;
|
targets = fBuildAttrs.bazelTargets;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
${
|
|
||||||
bazelCmd {
|
|
||||||
cmd = "run";
|
|
||||||
additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
|
|
||||||
# Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
|
|
||||||
targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
|
|
||||||
targetRunFlags = fBuildAttrs.runTargetFlags;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
runHook postBuild
|
runHook postBuild
|
||||||
'';
|
'';
|
||||||
})
|
})
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
, jaxlib-bin
|
, jaxlib-bin
|
||||||
, lapack
|
, lapack
|
||||||
, matplotlib
|
, matplotlib
|
||||||
, ml-dtypes
|
|
||||||
, numpy
|
, numpy
|
||||||
, opt-einsum
|
, opt-einsum
|
||||||
, pytestCheckHook
|
, pytestCheckHook
|
||||||
|
@ -28,7 +27,7 @@ let
|
||||||
in
|
in
|
||||||
buildPythonPackage rec {
|
buildPythonPackage rec {
|
||||||
pname = "jax";
|
pname = "jax";
|
||||||
version = "0.4.12";
|
version = "0.4.5";
|
||||||
format = "setuptools";
|
format = "setuptools";
|
||||||
|
|
||||||
disabled = pythonOlder "3.7";
|
disabled = pythonOlder "3.7";
|
||||||
|
@ -38,7 +37,7 @@ buildPythonPackage rec {
|
||||||
repo = pname;
|
repo = pname;
|
||||||
# google/jax contains tags for jax and jaxlib. Only use jax tags!
|
# google/jax contains tags for jax and jaxlib. Only use jax tags!
|
||||||
rev = "refs/tags/${pname}-v${version}";
|
rev = "refs/tags/${pname}-v${version}";
|
||||||
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
|
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
|
||||||
};
|
};
|
||||||
|
|
||||||
# jaxlib is _not_ included in propagatedBuildInputs because there are
|
# jaxlib is _not_ included in propagatedBuildInputs because there are
|
||||||
|
@ -47,7 +46,6 @@ buildPythonPackage rec {
|
||||||
propagatedBuildInputs = [
|
propagatedBuildInputs = [
|
||||||
absl-py
|
absl-py
|
||||||
etils
|
etils
|
||||||
ml-dtypes
|
|
||||||
numpy
|
numpy
|
||||||
opt-einsum
|
opt-einsum
|
||||||
scipy
|
scipy
|
||||||
|
@ -98,12 +96,24 @@ buildPythonPackage rec {
|
||||||
"testScanGrad_jit_scan"
|
"testScanGrad_jit_scan"
|
||||||
];
|
];
|
||||||
|
|
||||||
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
# See https://github.com/google/jax/issues/11722. This is a temporary fix in
|
||||||
|
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
|
||||||
|
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
|
||||||
|
disabledTestPaths = [
|
||||||
|
"tests/api_test.py"
|
||||||
|
"tests/core_test.py"
|
||||||
|
"tests/lax_numpy_indexing_test.py"
|
||||||
|
"tests/lax_numpy_test.py"
|
||||||
|
"tests/nn_test.py"
|
||||||
|
"tests/random_test.py"
|
||||||
|
"tests/sparse_test.py"
|
||||||
|
] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
||||||
# RuntimeWarning: invalid value encountered in cast
|
# RuntimeWarning: invalid value encountered in cast
|
||||||
"tests/lax_test.py"
|
"tests/lax_test.py"
|
||||||
];
|
];
|
||||||
|
|
||||||
pythonImportsCheck = [ "jax" ];
|
# As of 0.3.22, `import jax` does not work without jaxlib being installed.
|
||||||
|
pythonImportsCheck = [ ];
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "Differentiate, compile, and transform Numpy code";
|
description = "Differentiate, compile, and transform Numpy code";
|
||||||
|
|
|
@ -18,12 +18,11 @@
|
||||||
, autoPatchelfHook
|
, autoPatchelfHook
|
||||||
, buildPythonPackage
|
, buildPythonPackage
|
||||||
, config
|
, config
|
||||||
, fetchPypi
|
, cudnn ? cudaPackages.cudnn
|
||||||
, fetchurl
|
, fetchurl
|
||||||
, flatbuffers
|
, flatbuffers
|
||||||
, jaxlib
|
, isPy39
|
||||||
, lib
|
, lib
|
||||||
, ml-dtypes
|
|
||||||
, python
|
, python
|
||||||
, scipy
|
, scipy
|
||||||
, stdenv
|
, stdenv
|
||||||
|
@ -36,57 +35,46 @@ let
|
||||||
inherit (cudaPackages) cudatoolkit cudnn;
|
inherit (cudaPackages) cudatoolkit cudnn;
|
||||||
in
|
in
|
||||||
|
|
||||||
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;
|
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
|
||||||
|
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
|
||||||
|
|
||||||
let
|
let
|
||||||
version = "0.4.12";
|
version = "0.4.4";
|
||||||
|
|
||||||
inherit (python) pythonVersion;
|
|
||||||
|
|
||||||
# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
|
|
||||||
# official instructions recommend installing CPU-only versions via PyPI.
|
|
||||||
cpuSrcs =
|
|
||||||
let
|
|
||||||
getSrcFromPypi = { platform, hash }: fetchPypi {
|
|
||||||
inherit version platform hash;
|
|
||||||
pname = "jaxlib";
|
|
||||||
format = "wheel";
|
|
||||||
# See the `disabled` attr comment below.
|
|
||||||
dist = "cp310";
|
|
||||||
python = "cp310";
|
|
||||||
abi = "cp310";
|
|
||||||
};
|
|
||||||
in
|
|
||||||
{
|
|
||||||
"x86_64-linux" = getSrcFromPypi {
|
|
||||||
platform = "manylinux2014_x86_64";
|
|
||||||
hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc=";
|
|
||||||
};
|
|
||||||
"aarch64-darwin" = getSrcFromPypi {
|
|
||||||
platform = "macosx_11_0_arm64";
|
|
||||||
hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo=";
|
|
||||||
};
|
|
||||||
"x86_64-darwin" = getSrcFromPypi {
|
|
||||||
platform = "macosx_10_14_x86_64";
|
|
||||||
hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c=";
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
|
pythonVersion = python.pythonVersion;
|
||||||
|
|
||||||
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
|
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
|
||||||
# When upgrading, you can get these hashes from prefetch.sh. See
|
# When upgrading, you can get these hashes from prefetch.sh. See
|
||||||
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
|
# https://github.com/google/jax/issues/12879 as to why this specific URL is
|
||||||
gpuSrc = fetchurl {
|
# the correct index.
|
||||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
|
cpuSrcs = {
|
||||||
hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo=";
|
"x86_64-linux" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
|
||||||
|
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
|
||||||
|
};
|
||||||
|
"aarch64-darwin" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
|
||||||
|
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
|
||||||
|
};
|
||||||
|
"x86_64-darwin" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
|
||||||
|
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
gpuSrc = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
|
||||||
|
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
|
||||||
|
};
|
||||||
in
|
in
|
||||||
buildPythonPackage {
|
buildPythonPackage rec {
|
||||||
pname = "jaxlib";
|
pname = "jaxlib";
|
||||||
inherit version;
|
inherit version;
|
||||||
format = "wheel";
|
format = "wheel";
|
||||||
|
|
||||||
|
# At the time of writing (2022-10-19), there are releases for <=3.10.
|
||||||
|
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
|
||||||
|
# python version.
|
||||||
disabled = !(pythonVersion == "3.10");
|
disabled = !(pythonVersion == "3.10");
|
||||||
|
|
||||||
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
|
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
|
||||||
|
@ -99,10 +87,9 @@ buildPythonPackage {
|
||||||
|
|
||||||
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
||||||
# Run `autoPatchelfHook` to automagically fix them.
|
# Run `autoPatchelfHook` to automagically fix them.
|
||||||
nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
|
nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
|
||||||
++ lib.optionals cudaSupport [ addOpenGLRunpath ];
|
|
||||||
# Dynamic link dependencies
|
# Dynamic link dependencies
|
||||||
buildInputs = [ stdenv.cc.cc.lib ];
|
buildInputs = [ stdenv.cc.cc ];
|
||||||
|
|
||||||
# jaxlib contains shared libraries that open other shared libraries via dlopen
|
# jaxlib contains shared libraries that open other shared libraries via dlopen
|
||||||
# and these implicit dependencies are not recognized by ldd or
|
# and these implicit dependencies are not recognized by ldd or
|
||||||
|
@ -126,12 +113,7 @@ buildPythonPackage {
|
||||||
done
|
done
|
||||||
'';
|
'';
|
||||||
|
|
||||||
propagatedBuildInputs = [
|
propagatedBuildInputs = [ absl-py flatbuffers scipy ];
|
||||||
absl-py
|
|
||||||
flatbuffers
|
|
||||||
ml-dtypes
|
|
||||||
scipy
|
|
||||||
];
|
|
||||||
|
|
||||||
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
|
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
|
||||||
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||||
|
@ -141,7 +123,7 @@ buildPythonPackage {
|
||||||
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
|
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
|
||||||
'';
|
'';
|
||||||
|
|
||||||
inherit (jaxlib) pythonImportsCheck;
|
pythonImportsCheck = [ "jaxlib" ];
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "XLA library for JAX";
|
description = "XLA library for JAX";
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
# Build-time dependencies:
|
# Build-time dependencies:
|
||||||
, addOpenGLRunpath
|
, addOpenGLRunpath
|
||||||
, bazel_6
|
, bazel_5
|
||||||
, binutils
|
, binutils
|
||||||
, buildBazelPackage
|
, buildBazelPackage
|
||||||
, buildPythonPackage
|
, buildPythonPackage
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
# Python dependencies:
|
# Python dependencies:
|
||||||
, absl-py
|
, absl-py
|
||||||
, flatbuffers
|
, flatbuffers
|
||||||
, ml-dtypes
|
|
||||||
, numpy
|
, numpy
|
||||||
, scipy
|
, scipy
|
||||||
, six
|
, six
|
||||||
|
@ -36,6 +35,7 @@
|
||||||
, giflib
|
, giflib
|
||||||
, grpc
|
, grpc
|
||||||
, libjpeg_turbo
|
, libjpeg_turbo
|
||||||
|
, protobuf
|
||||||
, python
|
, python
|
||||||
, snappy
|
, snappy
|
||||||
, zlib
|
, zlib
|
||||||
|
@ -53,7 +53,7 @@ let
|
||||||
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
||||||
|
|
||||||
pname = "jaxlib";
|
pname = "jaxlib";
|
||||||
version = "0.4.12";
|
version = "0.4.4";
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||||
|
@ -138,15 +138,14 @@ let
|
||||||
bazel-build = buildBazelPackage rec {
|
bazel-build = buildBazelPackage rec {
|
||||||
name = "bazel-build-${pname}-${version}";
|
name = "bazel-build-${pname}-${version}";
|
||||||
|
|
||||||
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
|
bazel = bazel_5;
|
||||||
bazel = bazel_6;
|
|
||||||
|
|
||||||
src = fetchFromGitHub {
|
src = fetchFromGitHub {
|
||||||
owner = "google";
|
owner = "google";
|
||||||
repo = "jax";
|
repo = "jax";
|
||||||
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
||||||
rev = "refs/tags/${pname}-v${version}";
|
rev = "refs/tags/${pname}-v${version}";
|
||||||
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
|
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
|
||||||
};
|
};
|
||||||
|
|
||||||
nativeBuildInputs = [
|
nativeBuildInputs = [
|
||||||
|
@ -170,7 +169,7 @@ let
|
||||||
numpy
|
numpy
|
||||||
openssl
|
openssl
|
||||||
pkgs.flatbuffers
|
pkgs.flatbuffers
|
||||||
pkgs.protobuf
|
protobuf
|
||||||
pybind11
|
pybind11
|
||||||
scipy
|
scipy
|
||||||
six
|
six
|
||||||
|
@ -189,8 +188,7 @@ let
|
||||||
rm -f .bazelversion
|
rm -f .bazelversion
|
||||||
'';
|
'';
|
||||||
|
|
||||||
bazelRunTarget = "//build:build_wheel";
|
bazelTargets = [ "//build:build_wheel" ];
|
||||||
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
|
|
||||||
|
|
||||||
removeRulesCC = false;
|
removeRulesCC = false;
|
||||||
|
|
||||||
|
@ -209,11 +207,7 @@ let
|
||||||
build --action_env=PYENV_ROOT
|
build --action_env=PYENV_ROOT
|
||||||
build --python_path="${python}/bin/python"
|
build --python_path="${python}/bin/python"
|
||||||
build --distinct_host_configuration=false
|
build --distinct_host_configuration=false
|
||||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
|
||||||
'' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
|
|
||||||
build --config=avx_posix
|
|
||||||
'' + lib.optionalString mklSupport ''
|
|
||||||
build --config=mkl_open_source_only
|
|
||||||
'' + lib.optionalString cudaSupport ''
|
'' + lib.optionalString cudaSupport ''
|
||||||
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
|
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
|
||||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||||
|
@ -240,7 +234,7 @@ let
|
||||||
fetchAttrs = {
|
fetchAttrs = {
|
||||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
|
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
|
||||||
# we have to force @mkl_dnn_v1 since it's not needed on darwin
|
# we have to force @mkl_dnn_v1 since it's not needed on darwin
|
||||||
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
|
bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
|
||||||
bazelFlags = bazelFlags ++ [
|
bazelFlags = bazelFlags ++ [
|
||||||
"--config=avx_posix"
|
"--config=avx_posix"
|
||||||
] ++ lib.optionals cudaSupport [
|
] ++ lib.optionals cudaSupport [
|
||||||
|
@ -255,9 +249,9 @@ let
|
||||||
|
|
||||||
sha256 =
|
sha256 =
|
||||||
if cudaSupport then
|
if cudaSupport then
|
||||||
"sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4="
|
"sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
|
||||||
else
|
else
|
||||||
"sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc=";
|
"sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
|
||||||
};
|
};
|
||||||
|
|
||||||
buildAttrs = {
|
buildAttrs = {
|
||||||
|
@ -267,13 +261,25 @@ let
|
||||||
"nsync" # fails to build on darwin
|
"nsync" # fails to build on darwin
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
|
||||||
|
"--config=avx_posix"
|
||||||
|
] ++ lib.optionals cudaSupport [
|
||||||
|
"--config=cuda"
|
||||||
|
] ++ lib.optionals mklSupport [
|
||||||
|
"--config=mkl_open_source_only"
|
||||||
|
];
|
||||||
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
|
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
|
||||||
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
# 1) Fix pybind11 include paths.
|
||||||
|
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||||
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
||||||
# 2) Patch python path in the compiler driver.
|
# 3) Patch python path in the compiler driver.
|
||||||
preBuild = lib.optionalString cudaSupport ''
|
preBuild = ''
|
||||||
|
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
|
||||||
|
sed -i 's@include/pybind11@pybind11@g' $src
|
||||||
|
done
|
||||||
|
'' + lib.optionalString cudaSupport ''
|
||||||
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
|
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
|
||||||
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||||
'' + lib.optionalString stdenv.isDarwin ''
|
'' + lib.optionalString stdenv.isDarwin ''
|
||||||
# Framework search paths aren't added by bintools hook
|
# Framework search paths aren't added by bintools hook
|
||||||
# https://github.com/NixOS/nixpkgs/pull/41914
|
# https://github.com/NixOS/nixpkgs/pull/41914
|
||||||
|
@ -283,12 +289,16 @@ let
|
||||||
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
|
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
|
||||||
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
|
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
|
||||||
'' + (if stdenv.cc.isGNU then ''
|
'' + (if stdenv.cc.isGNU then ''
|
||||||
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||||
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||||
'' else if stdenv.cc.isClang then ''
|
'' else if stdenv.cc.isClang then ''
|
||||||
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||||
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||||
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
|
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
|
||||||
|
|
||||||
|
installPhase = ''
|
||||||
|
./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
|
||||||
|
'';
|
||||||
};
|
};
|
||||||
|
|
||||||
inherit meta;
|
inherit meta;
|
||||||
|
@ -335,19 +345,13 @@ buildPythonPackage {
|
||||||
grpc
|
grpc
|
||||||
jsoncpp
|
jsoncpp
|
||||||
libjpeg_turbo
|
libjpeg_turbo
|
||||||
ml-dtypes
|
|
||||||
numpy
|
numpy
|
||||||
scipy
|
scipy
|
||||||
six
|
six
|
||||||
snappy
|
snappy
|
||||||
];
|
];
|
||||||
|
|
||||||
pythonImportsCheck = [
|
pythonImportsCheck = [ "jaxlib" ];
|
||||||
"jaxlib"
|
|
||||||
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
|
|
||||||
"jaxlib.cpu_feature_guard"
|
|
||||||
"jaxlib.xla_client"
|
|
||||||
];
|
|
||||||
|
|
||||||
# Without it there are complaints about libcudart.so.11.0 not being found
|
# Without it there are complaints about libcudart.so.11.0 not being found
|
||||||
# because RPATH path entries added above are stripped.
|
# because RPATH path entries added above are stripped.
|
||||||
|
|
|
@ -1,15 +1,7 @@
|
||||||
#!/usr/bin/env bash
|
version="$1"
|
||||||
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
|
||||||
prefetch () {
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
|
||||||
expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url"
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
|
||||||
url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r)
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
|
||||||
echo "$url"
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
|
||||||
sha256=$(nix-prefetch-url "$url")
|
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
|
||||||
nix hash to-sri --type sha256 "$sha256"
|
|
||||||
echo
|
|
||||||
}
|
|
||||||
|
|
||||||
prefetch "x86_64-linux" "false"
|
|
||||||
prefetch "aarch64-darwin" "false"
|
|
||||||
prefetch "x86_64-darwin" "false"
|
|
||||||
prefetch "x86_64-linux" "true"
|
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
{ lib
|
|
||||||
, buildPythonPackage
|
|
||||||
, fetchFromGitHub
|
|
||||||
, numpy
|
|
||||||
, pybind11
|
|
||||||
, pythonOlder
|
|
||||||
}:
|
|
||||||
|
|
||||||
buildPythonPackage rec {
|
|
||||||
pname = "ml-dtypes";
|
|
||||||
version = "0.2.0";
|
|
||||||
|
|
||||||
disabled = pythonOlder "3.7";
|
|
||||||
|
|
||||||
src = fetchFromGitHub {
|
|
||||||
owner = "jax-ml";
|
|
||||||
repo = "ml_dtypes";
|
|
||||||
rev = "refs/tags/v${version}";
|
|
||||||
hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI=";
|
|
||||||
# Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521),
|
|
||||||
# the attempts to use the nixpkgs packaged eigen dependency have failed.
|
|
||||||
# Hence, we rely on the bundled eigen library.
|
|
||||||
fetchSubmodules = true;
|
|
||||||
};
|
|
||||||
|
|
||||||
nativeBuildInputs = [ pybind11 ];
|
|
||||||
|
|
||||||
propagatedBuildInputs = [ numpy ];
|
|
||||||
|
|
||||||
pythonImportsCheck = [ "ml_dtypes" ];
|
|
||||||
|
|
||||||
meta = with lib; {
|
|
||||||
description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries";
|
|
||||||
homepage = "https://github.com/jax-ml/ml_dtypes";
|
|
||||||
license = licenses.asl20;
|
|
||||||
maintainers = with maintainers; [ GaetanLepage samuela ];
|
|
||||||
};
|
|
||||||
}
|
|
|
@ -5310,6 +5310,7 @@ self: super: with self; {
|
||||||
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
|
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
|
||||||
inherit (pkgs.config) cudaSupport;
|
inherit (pkgs.config) cudaSupport;
|
||||||
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
||||||
|
protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
|
||||||
};
|
};
|
||||||
|
|
||||||
jaxlib = self.jaxlib-build;
|
jaxlib = self.jaxlib-build;
|
||||||
|
@ -6563,8 +6564,6 @@ self: super: with self; {
|
||||||
|
|
||||||
ml-collections = callPackage ../development/python-modules/ml-collections { };
|
ml-collections = callPackage ../development/python-modules/ml-collections { };
|
||||||
|
|
||||||
ml-dtypes = callPackage ../development/python-modules/ml-dtypes { };
|
|
||||||
|
|
||||||
mlflow = callPackage ../development/python-modules/mlflow { };
|
mlflow = callPackage ../development/python-modules/mlflow { };
|
||||||
|
|
||||||
mlrose = callPackage ../development/python-modules/mlrose { };
|
mlrose = callPackage ../development/python-modules/mlrose { };
|
||||||
|
|
Loading…
Reference in a new issue