Merge pull request #227145 from GaetanLepage/jax

Update JAX
This commit is contained in:
Nick Cao 2023-08-01 19:42:01 -06:00 committed by GitHub
commit 2673bcc912
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 173 additions and 103 deletions

View file

@ -10,9 +10,12 @@ args@{
, bazelFlags ? []
, bazelBuildFlags ? []
, bazelTestFlags ? []
, bazelRunFlags ? []
, runTargetFlags ? []
, bazelFetchFlags ? []
, bazelTargets
, bazelTargets ? []
, bazelTestTargets ? []
, bazelRunTarget ? null
, buildAttrs
, fetchAttrs
@ -46,17 +49,23 @@ args@{
let
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
name = name;
bazelFlags = bazelFlags;
bazelBuildFlags = bazelBuildFlags;
bazelTestFlags = bazelTestFlags;
bazelFetchFlags = bazelFetchFlags;
bazelTestTargets = bazelTestTargets;
dontAddBazelOpts = dontAddBazelOpts;
inherit
name
bazelFlags
bazelBuildFlags
bazelTestFlags
bazelRunFlags
runTargetFlags
bazelFetchFlags
bazelTargets
bazelTestTargets
bazelRunTarget
dontAddBazelOpts
;
};
fBuildAttrs = fArgs // buildAttrs;
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
bazelCmd = { cmd, additionalFlags, targets }:
bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
lib.optionalString (targets != [ ]) ''
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
@ -73,7 +82,8 @@ let
"''${host_linkopts[@]}" \
$bazelFlags \
${lib.strings.concatStringsSep " " additionalFlags} \
${lib.strings.concatStringsSep " " targets}
${lib.strings.concatStringsSep " " targets} \
${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
'';
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
# chmod: cannot operate on dangling symlink '$symlink'
@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // {
targets = fBuildAttrs.bazelTargets;
}
}
${
bazelCmd {
cmd = "run";
additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
# Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
targetRunFlags = fBuildAttrs.runTargetFlags;
}
}
runHook postBuild
'';
})

View file

@ -8,6 +8,7 @@
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
@ -27,7 +28,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.5";
version = "0.4.12";
format = "setuptools";
disabled = pythonOlder "3.7";
@ -37,7 +38,7 @@ buildPythonPackage rec {
repo = pname;
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
};
# jaxlib is _not_ included in propagatedBuildInputs because there are
@ -46,6 +47,7 @@ buildPythonPackage rec {
propagatedBuildInputs = [
absl-py
etils
ml-dtypes
numpy
opt-einsum
scipy
@ -96,24 +98,12 @@ buildPythonPackage rec {
"testScanGrad_jit_scan"
];
# See https://github.com/google/jax/issues/11722. This is a temporary fix in
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
disabledTestPaths = [
"tests/api_test.py"
"tests/core_test.py"
"tests/lax_numpy_indexing_test.py"
"tests/lax_numpy_test.py"
"tests/nn_test.py"
"tests/random_test.py"
"tests/sparse_test.py"
] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
# RuntimeWarning: invalid value encountered in cast
"tests/lax_test.py"
];
# As of 0.3.22, `import jax` does not work without jaxlib being installed.
pythonImportsCheck = [ ];
pythonImportsCheck = [ "jax" ];
meta = with lib; {
description = "Differentiate, compile, and transform Numpy code";

View file

@ -18,11 +18,12 @@
, autoPatchelfHook
, buildPythonPackage
, config
, cudnn ? cudaPackages.cudnn
, fetchPypi
, fetchurl
, flatbuffers
, isPy39
, jaxlib
, lib
, ml-dtypes
, python
, scipy
, stdenv
@ -35,46 +36,57 @@ let
inherit (cudaPackages) cudatoolkit cudnn;
in
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;
let
version = "0.4.4";
version = "0.4.12";
inherit (python) pythonVersion;
# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
# official instructions recommend installing CPU-only versions via PyPI.
cpuSrcs =
let
getSrcFromPypi = { platform, hash }: fetchPypi {
inherit version platform hash;
pname = "jaxlib";
format = "wheel";
# See the `disabled` attr comment below.
dist = "cp310";
python = "cp310";
abi = "cp310";
};
in
{
"x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc=";
};
"aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo=";
};
"x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c=";
};
};
pythonVersion = python.pythonVersion;
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
# When upgrading, you can get these hashes from prefetch.sh. See
# https://github.com/google/jax/issues/12879 as to why this specific URL is
# the correct index.
cpuSrcs = {
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
};
"aarch64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
};
"x86_64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
};
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo=";
};
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
};
in
buildPythonPackage rec {
buildPythonPackage {
pname = "jaxlib";
inherit version;
format = "wheel";
# At the time of writing (2022-10-19), there are releases for <=3.10.
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
# python version.
disabled = !(pythonVersion == "3.10");
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
@ -87,9 +99,10 @@ buildPythonPackage rec {
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
++ lib.optionals cudaSupport [ addOpenGLRunpath ];
# Dynamic link dependencies
buildInputs = [ stdenv.cc.cc ];
buildInputs = [ stdenv.cc.cc.lib ];
# jaxlib contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or
@ -113,7 +126,12 @@ buildPythonPackage rec {
done
'';
propagatedBuildInputs = [ absl-py flatbuffers scipy ];
propagatedBuildInputs = [
absl-py
flatbuffers
ml-dtypes
scipy
];
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
@ -123,7 +141,7 @@ buildPythonPackage rec {
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
'';
pythonImportsCheck = [ "jaxlib" ];
inherit (jaxlib) pythonImportsCheck;
meta = with lib; {
description = "XLA library for JAX";

View file

@ -4,7 +4,7 @@
# Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
@ -26,6 +26,7 @@
# Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six
@ -35,7 +36,6 @@
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib
@ -53,7 +53,7 @@ let
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
pname = "jaxlib";
version = "0.4.4";
version = "0.4.12";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -138,14 +138,15 @@ let
bazel-build = buildBazelPackage rec {
name = "bazel-build-${pname}-${version}";
bazel = bazel_5;
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
bazel = bazel_6;
src = fetchFromGitHub {
owner = "google";
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
};
nativeBuildInputs = [
@ -169,7 +170,7 @@ let
numpy
openssl
pkgs.flatbuffers
protobuf
pkgs.protobuf
pybind11
scipy
six
@ -188,7 +189,8 @@ let
rm -f .bazelversion
'';
bazelTargets = [ "//build:build_wheel" ];
bazelRunTarget = "//build:build_wheel";
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
removeRulesCC = false;
@ -207,7 +209,11 @@ let
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
'' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
build --config=avx_posix
'' + lib.optionalString mklSupport ''
build --config=mkl_open_source_only
'' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@ -234,7 +240,7 @@ let
fetchAttrs = {
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
# we have to force @mkl_dnn_v1 since it's not needed on darwin
bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
bazelFlags = bazelFlags ++ [
"--config=avx_posix"
] ++ lib.optionals cudaSupport [
@ -249,9 +255,9 @@ let
sha256 =
if cudaSupport then
"sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
"sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4="
else
"sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
"sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc=";
};
buildAttrs = {
@ -261,25 +267,13 @@ let
"nsync" # fails to build on darwin
]);
bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
] ++ lib.optionals cudaSupport [
"--config=cuda"
] ++ lib.optionals mklSupport [
"--config=mkl_open_source_only"
];
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
preBuild = ''
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
'' + lib.optionalString cudaSupport ''
# 2) Patch python path in the compiler driver.
preBuild = lib.optionalString cudaSupport ''
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook
# https://github.com/NixOS/nixpkgs/pull/41914
@ -289,16 +283,12 @@ let
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
'' + (if stdenv.cc.isGNU then ''
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
'' else if stdenv.cc.isClang then ''
sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
'';
};
inherit meta;
@ -345,13 +335,19 @@ buildPythonPackage {
grpc
jsoncpp
libjpeg_turbo
ml-dtypes
numpy
scipy
six
snappy
];
pythonImportsCheck = [ "jaxlib" ];
pythonImportsCheck = [
"jaxlib"
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
"jaxlib.cpu_feature_guard"
"jaxlib.xla_client"
];
# Without it there are complaints about libcudart.so.11.0 not being found
# because RPATH path entries added above are stripped.

View file

@ -1,7 +1,15 @@
version="$1"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
#!/usr/bin/env bash
prefetch () {
expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url"
url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r)
echo "$url"
sha256=$(nix-prefetch-url "$url")
nix hash to-sri --type sha256 "$sha256"
echo
}
prefetch "x86_64-linux" "false"
prefetch "aarch64-darwin" "false"
prefetch "x86_64-darwin" "false"
prefetch "x86_64-linux" "true"

View file

@ -0,0 +1,38 @@
{ lib
, buildPythonPackage
, fetchFromGitHub
, numpy
, pybind11
, pythonOlder
}:
buildPythonPackage rec {
pname = "ml-dtypes";
version = "0.2.0";
disabled = pythonOlder "3.7";
src = fetchFromGitHub {
owner = "jax-ml";
repo = "ml_dtypes";
rev = "refs/tags/v${version}";
hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI=";
# Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521),
# the attempts to use the nixpkgs packaged eigen dependency have failed.
# Hence, we rely on the bundled eigen library.
fetchSubmodules = true;
};
nativeBuildInputs = [ pybind11 ];
propagatedBuildInputs = [ numpy ];
pythonImportsCheck = [ "ml_dtypes" ];
meta = with lib; {
description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries";
homepage = "https://github.com/jax-ml/ml_dtypes";
license = licenses.asl20;
maintainers = with maintainers; [ GaetanLepage samuela ];
};
}

View file

@ -5310,7 +5310,6 @@ self: super: with self; {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
inherit (pkgs.config) cudaSupport;
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
};
jaxlib = self.jaxlib-build;
@ -6564,6 +6563,8 @@ self: super: with self; {
ml-collections = callPackage ../development/python-modules/ml-collections { };
ml-dtypes = callPackage ../development/python-modules/ml-dtypes { };
mlflow = callPackage ../development/python-modules/mlflow { };
mlrose = callPackage ../development/python-modules/mlrose { };