Revert "Update JAX"

This commit is contained in:
Nick Cao 2023-08-01 21:23:27 -06:00 committed by GitHub
parent cf6c391838
commit 8423edb179
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 173 deletions

View file

@ -10,12 +10,9 @@ args@{
, bazelFlags ? [] , bazelFlags ? []
, bazelBuildFlags ? [] , bazelBuildFlags ? []
, bazelTestFlags ? [] , bazelTestFlags ? []
, bazelRunFlags ? []
, runTargetFlags ? []
, bazelFetchFlags ? [] , bazelFetchFlags ? []
, bazelTargets ? [] , bazelTargets
, bazelTestTargets ? [] , bazelTestTargets ? []
, bazelRunTarget ? null
, buildAttrs , buildAttrs
, fetchAttrs , fetchAttrs
@ -49,23 +46,17 @@ args@{
let let
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
inherit name = name;
name bazelFlags = bazelFlags;
bazelFlags bazelBuildFlags = bazelBuildFlags;
bazelBuildFlags bazelTestFlags = bazelTestFlags;
bazelTestFlags bazelFetchFlags = bazelFetchFlags;
bazelRunFlags bazelTestTargets = bazelTestTargets;
runTargetFlags dontAddBazelOpts = dontAddBazelOpts;
bazelFetchFlags
bazelTargets
bazelTestTargets
bazelRunTarget
dontAddBazelOpts
;
}; };
fBuildAttrs = fArgs // buildAttrs; fBuildAttrs = fArgs // buildAttrs;
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: bazelCmd = { cmd, additionalFlags, targets }:
lib.optionalString (targets != [ ]) '' lib.optionalString (targets != [ ]) ''
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
@ -82,8 +73,7 @@ let
"''${host_linkopts[@]}" \ "''${host_linkopts[@]}" \
$bazelFlags \ $bazelFlags \
${lib.strings.concatStringsSep " " additionalFlags} \ ${lib.strings.concatStringsSep " " additionalFlags} \
${lib.strings.concatStringsSep " " targets} \ ${lib.strings.concatStringsSep " " targets}
${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
''; '';
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
# chmod: cannot operate on dangling symlink '$symlink' # chmod: cannot operate on dangling symlink '$symlink'
@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // {
targets = fBuildAttrs.bazelTargets; targets = fBuildAttrs.bazelTargets;
} }
} }
${
bazelCmd {
cmd = "run";
additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
# Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
targetRunFlags = fBuildAttrs.runTargetFlags;
}
}
runHook postBuild runHook postBuild
''; '';
}) })

View file

@ -8,7 +8,6 @@
, jaxlib-bin , jaxlib-bin
, lapack , lapack
, matplotlib , matplotlib
, ml-dtypes
, numpy , numpy
, opt-einsum , opt-einsum
, pytestCheckHook , pytestCheckHook
@ -28,7 +27,7 @@ let
in in
buildPythonPackage rec { buildPythonPackage rec {
pname = "jax"; pname = "jax";
version = "0.4.12"; version = "0.4.5";
format = "setuptools"; format = "setuptools";
disabled = pythonOlder "3.7"; disabled = pythonOlder "3.7";
@ -38,7 +37,7 @@ buildPythonPackage rec {
repo = pname; repo = pname;
# google/jax contains tags for jax and jaxlib. Only use jax tags! # google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}"; rev = "refs/tags/${pname}-v${version}";
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
}; };
# jaxlib is _not_ included in propagatedBuildInputs because there are # jaxlib is _not_ included in propagatedBuildInputs because there are
@ -47,7 +46,6 @@ buildPythonPackage rec {
propagatedBuildInputs = [ propagatedBuildInputs = [
absl-py absl-py
etils etils
ml-dtypes
numpy numpy
opt-einsum opt-einsum
scipy scipy
@ -98,12 +96,24 @@ buildPythonPackage rec {
"testScanGrad_jit_scan" "testScanGrad_jit_scan"
]; ];
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # See https://github.com/google/jax/issues/11722. This is a temporary fix in
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
disabledTestPaths = [
"tests/api_test.py"
"tests/core_test.py"
"tests/lax_numpy_indexing_test.py"
"tests/lax_numpy_test.py"
"tests/nn_test.py"
"tests/random_test.py"
"tests/sparse_test.py"
] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
# RuntimeWarning: invalid value encountered in cast # RuntimeWarning: invalid value encountered in cast
"tests/lax_test.py" "tests/lax_test.py"
]; ];
pythonImportsCheck = [ "jax" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed.
pythonImportsCheck = [ ];
meta = with lib; { meta = with lib; {
description = "Differentiate, compile, and transform Numpy code"; description = "Differentiate, compile, and transform Numpy code";

View file

@ -18,12 +18,11 @@
, autoPatchelfHook , autoPatchelfHook
, buildPythonPackage , buildPythonPackage
, config , config
, fetchPypi , cudnn ? cudaPackages.cudnn
, fetchurl , fetchurl
, flatbuffers , flatbuffers
, jaxlib , isPy39
, lib , lib
, ml-dtypes
, python , python
, scipy , scipy
, stdenv , stdenv
@ -36,57 +35,46 @@ let
inherit (cudaPackages) cudatoolkit cudnn; inherit (cudaPackages) cudatoolkit cudnn;
in in
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
let let
version = "0.4.12"; version = "0.4.4";
inherit (python) pythonVersion;
# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
# official instructions recommend installing CPU-only versions via PyPI.
cpuSrcs =
let
getSrcFromPypi = { platform, hash }: fetchPypi {
inherit version platform hash;
pname = "jaxlib";
format = "wheel";
# See the `disabled` attr comment below.
dist = "cp310";
python = "cp310";
abi = "cp310";
};
in
{
"x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc=";
};
"aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo=";
};
"x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c=";
};
};
pythonVersion = python.pythonVersion;
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
# When upgrading, you can get these hashes from prefetch.sh. See # When upgrading, you can get these hashes from prefetch.sh. See
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. # https://github.com/google/jax/issues/12879 as to why this specific URL is
gpuSrc = fetchurl { # the correct index.
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; cpuSrcs = {
hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; "x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
};
"aarch64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
};
"x86_64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
};
}; };
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
};
in in
buildPythonPackage { buildPythonPackage rec {
pname = "jaxlib"; pname = "jaxlib";
inherit version; inherit version;
format = "wheel"; format = "wheel";
# At the time of writing (2022-10-19), there are releases for <=3.10.
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
# python version.
disabled = !(pythonVersion == "3.10"); disabled = !(pythonVersion == "3.10");
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
@ -99,10 +87,9 @@ buildPythonPackage {
# Prebuilt wheels are dynamically linked against things that nix can't find. # Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them. # Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
++ lib.optionals cudaSupport [ addOpenGLRunpath ];
# Dynamic link dependencies # Dynamic link dependencies
buildInputs = [ stdenv.cc.cc.lib ]; buildInputs = [ stdenv.cc.cc ];
# jaxlib contains shared libraries that open other shared libraries via dlopen # jaxlib contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or # and these implicit dependencies are not recognized by ldd or
@ -126,12 +113,7 @@ buildPythonPackage {
done done
''; '';
propagatedBuildInputs = [ propagatedBuildInputs = [ absl-py flatbuffers scipy ];
absl-py
flatbuffers
ml-dtypes
scipy
];
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
@ -141,7 +123,7 @@ buildPythonPackage {
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
''; '';
inherit (jaxlib) pythonImportsCheck; pythonImportsCheck = [ "jaxlib" ];
meta = with lib; { meta = with lib; {
description = "XLA library for JAX"; description = "XLA library for JAX";

View file

@ -4,7 +4,7 @@
# Build-time dependencies: # Build-time dependencies:
, addOpenGLRunpath , addOpenGLRunpath
, bazel_6 , bazel_5
, binutils , binutils
, buildBazelPackage , buildBazelPackage
, buildPythonPackage , buildPythonPackage
@ -26,7 +26,6 @@
# Python dependencies: # Python dependencies:
, absl-py , absl-py
, flatbuffers , flatbuffers
, ml-dtypes
, numpy , numpy
, scipy , scipy
, six , six
@ -36,6 +35,7 @@
, giflib , giflib
, grpc , grpc
, libjpeg_turbo , libjpeg_turbo
, protobuf
, python , python
, snappy , snappy
, zlib , zlib
@ -53,7 +53,7 @@ let
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
pname = "jaxlib"; pname = "jaxlib";
version = "0.4.12"; version = "0.4.4";
meta = with lib; { meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -138,15 +138,14 @@ let
bazel-build = buildBazelPackage rec { bazel-build = buildBazelPackage rec {
name = "bazel-build-${pname}-${version}"; name = "bazel-build-${pname}-${version}";
# See https://github.com/google/jax/blob/main/.bazelversion for the latest. bazel = bazel_5;
bazel = bazel_6;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "google"; owner = "google";
repo = "jax"; repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags! # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}"; rev = "refs/tags/${pname}-v${version}";
hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
}; };
nativeBuildInputs = [ nativeBuildInputs = [
@ -170,7 +169,7 @@ let
numpy numpy
openssl openssl
pkgs.flatbuffers pkgs.flatbuffers
pkgs.protobuf protobuf
pybind11 pybind11
scipy scipy
six six
@ -189,8 +188,7 @@ let
rm -f .bazelversion rm -f .bazelversion
''; '';
bazelRunTarget = "//build:build_wheel"; bazelTargets = [ "//build:build_wheel" ];
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
removeRulesCC = false; removeRulesCC = false;
@ -209,11 +207,7 @@ let
build --action_env=PYENV_ROOT build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python" build --python_path="${python}/bin/python"
build --distinct_host_configuration=false build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
'' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
build --config=avx_posix
'' + lib.optionalString mklSupport ''
build --config=mkl_open_source_only
'' + lib.optionalString cudaSupport '' '' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}" build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@ -240,7 +234,7 @@ let
fetchAttrs = { fetchAttrs = {
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
# we have to force @mkl_dnn_v1 since it's not needed on darwin # we have to force @mkl_dnn_v1 since it's not needed on darwin
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
bazelFlags = bazelFlags ++ [ bazelFlags = bazelFlags ++ [
"--config=avx_posix" "--config=avx_posix"
] ++ lib.optionals cudaSupport [ ] ++ lib.optionals cudaSupport [
@ -255,9 +249,9 @@ let
sha256 = sha256 =
if cudaSupport then if cudaSupport then
"sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
else else
"sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
}; };
buildAttrs = { buildAttrs = {
@ -267,13 +261,25 @@ let
"nsync" # fails to build on darwin "nsync" # fails to build on darwin
]); ]);
bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
] ++ lib.optionals cudaSupport [
"--config=cuda"
] ++ lib.optionals mklSupport [
"--config=mkl_open_source_only"
];
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # 1) Fix pybind11 include paths.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs. # loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 2) Patch python path in the compiler driver. # 3) Patch python path in the compiler driver.
preBuild = lib.optionalString cudaSupport '' preBuild = ''
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
'' + lib.optionalString cudaSupport ''
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin '' '' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook # Framework search paths aren't added by bintools hook
# https://github.com/NixOS/nixpkgs/pull/41914 # https://github.com/NixOS/nixpkgs/pull/41914
@ -283,12 +289,16 @@ let
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
--replace "/usr/bin/libtool" "${cctools}/bin/libtool" --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
'' + (if stdenv.cc.isGNU then '' '' + (if stdenv.cc.isGNU then ''
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else if stdenv.cc.isClang then '' '' else if stdenv.cc.isClang then ''
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
'';
}; };
inherit meta; inherit meta;
@ -335,19 +345,13 @@ buildPythonPackage {
grpc grpc
jsoncpp jsoncpp
libjpeg_turbo libjpeg_turbo
ml-dtypes
numpy numpy
scipy scipy
six six
snappy snappy
]; ];
pythonImportsCheck = [ pythonImportsCheck = [ "jaxlib" ];
"jaxlib"
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
"jaxlib.cpu_feature_guard"
"jaxlib.xla_client"
];
# Without it there are complaints about libcudart.so.11.0 not being found # Without it there are complaints about libcudart.so.11.0 not being found
# because RPATH path entries added above are stripped. # because RPATH path entries added above are stripped.

View file

@ -1,15 +1,7 @@
#!/usr/bin/env bash version="$1"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
prefetch () { nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
echo "$url" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
sha256=$(nix-prefetch-url "$url") nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$sha256"
echo
}
prefetch "x86_64-linux" "false"
prefetch "aarch64-darwin" "false"
prefetch "x86_64-darwin" "false"
prefetch "x86_64-linux" "true"

View file

@ -1,38 +0,0 @@
{ lib
, buildPythonPackage
, fetchFromGitHub
, numpy
, pybind11
, pythonOlder
}:
buildPythonPackage rec {
pname = "ml-dtypes";
version = "0.2.0";
disabled = pythonOlder "3.7";
src = fetchFromGitHub {
owner = "jax-ml";
repo = "ml_dtypes";
rev = "refs/tags/v${version}";
hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI=";
# Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521),
# the attempts to use the nixpkgs packaged eigen dependency have failed.
# Hence, we rely on the bundled eigen library.
fetchSubmodules = true;
};
nativeBuildInputs = [ pybind11 ];
propagatedBuildInputs = [ numpy ];
pythonImportsCheck = [ "ml_dtypes" ];
meta = with lib; {
description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries";
homepage = "https://github.com/jax-ml/ml_dtypes";
license = licenses.asl20;
maintainers = with maintainers; [ GaetanLepage samuela ];
};
}

View file

@ -5310,6 +5310,7 @@ self: super: with self; {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
inherit (pkgs.config) cudaSupport; inherit (pkgs.config) cudaSupport;
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
}; };
jaxlib = self.jaxlib-build; jaxlib = self.jaxlib-build;
@ -6563,8 +6564,6 @@ self: super: with self; {
ml-collections = callPackage ../development/python-modules/ml-collections { }; ml-collections = callPackage ../development/python-modules/ml-collections { };
ml-dtypes = callPackage ../development/python-modules/ml-dtypes { };
mlflow = callPackage ../development/python-modules/mlflow { }; mlflow = callPackage ../development/python-modules/mlflow { };
mlrose = callPackage ../development/python-modules/mlrose { }; mlrose = callPackage ../development/python-modules/mlrose { };