Merge pull request #159099 from samuela/upkeep-bot/python3Packages.jax-0.3.0-1644540822

python3Packages.jax: 0.2.28 -> 0.3.0
This commit is contained in:
Samuel Ainsworth 2022-02-14 13:48:40 -08:00 committed by GitHub
commit b0f066827d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 18 deletions

View file

@ -6,6 +6,7 @@
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, typing-extensions
@ -13,7 +14,7 @@
buildPythonPackage rec {
pname = "jax";
version = "0.2.28";
version = "0.3.0";
format = "setuptools";
disabled = pythonOlder "3.7";
@ -22,7 +23,7 @@ buildPythonPackage rec {
owner = "google";
repo = pname;
rev = "${pname}-v${version}";
sha256 = "1ky442zi5i8b5mk284s0i7dk8rh6vi9dvyqfscpij88g37clgpp0";
sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
};
patches = [
@ -45,6 +46,7 @@ buildPythonPackage rec {
checkInputs = [
jaxlib
pytestCheckHook
pytest-xdist
];
# NOTE: Don't run the tests in the expiremental directory as they require flax
@ -52,6 +54,7 @@ buildPythonPackage rec {
# Not a big deal, this is how the JAX docs suggest running the test suite
# anyhow.
pytestFlagsArray = [
"-n auto"
"-W ignore::DeprecationWarning"
"tests/"
];

View file

@ -13,11 +13,20 @@
# * https://github.com/google/jax/issues/971#issuecomment-508216439
# * https://github.com/google/jax/issues/5723#issuecomment-913038780
{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
, fetchurl, isPy39, lib, stdenv
# propagatedBuildInputs
, absl-py, flatbuffers, scipy, cudatoolkit_11, cudnn
# Options:
{ absl-py
, addOpenGLRunpath
, autoPatchelfHook
, buildPythonPackage
, config
, cudatoolkit_11
, cudnn
, fetchurl
, flatbuffers
, isPy39
, lib
, scipy
, stdenv
# Options:
, cudaSupport ? config.cudaSupport or false
}:
@ -32,7 +41,7 @@ let
in
buildPythonPackage rec {
pname = "jaxlib";
version = "0.1.75";
version = "0.3.0";
format = "wheel";
# At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
@ -44,7 +53,7 @@ buildPythonPackage rec {
src = {
cpu = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
sha256 = "1davmx9dvai8dq3h5ac82634gjhv6l46kq6baajrxjqczbp0w7m6";
sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01";
};
gpu = fetchurl {
# Note that there's also a release targeting cuDNN 8.2, but unfortunately
@ -52,7 +61,7 @@ buildPythonPackage rec {
# Check pkgs/development/libraries/science/math/cudnn/default.nix for more
# details.
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
sha256 = "1mk618lq1q5x0dc3xbid8bim59l9j6l47xq232gdbn401ykrid7r";
sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8";
# This is what the cuDNN 8.2 download looks like for future reference:
# url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
@ -95,8 +104,8 @@ buildPythonPackage rec {
meta = with lib; {
description = "XLA library for JAX";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
platforms = [ "x86_64-linux" ];
};

View file

@ -4,7 +4,7 @@
# Build-time dependencies:
, addOpenGLRunpath
, bazel_4
, bazel_5
, binutils
, buildBazelPackage
, buildPythonPackage
@ -50,7 +50,7 @@
let
pname = "jaxlib";
version = "0.1.75";
version = "0.3.0";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -82,13 +82,13 @@ let
bazel-build = buildBazelPackage {
name = "bazel-build-${pname}-${version}";
bazel = bazel_4;
bazel = bazel_5;
src = fetchFromGitHub {
owner = "google";
repo = "jax";
rev = "${pname}-v${version}";
sha256 = "01ks4djbpjsxjy2zwdwv3h00sgwi4ps3jz75swddrw2f56zjdmw4";
sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
};
nativeBuildInputs = [
@ -216,9 +216,9 @@ let
fetchAttrs = {
sha256 =
if cudaSupport then
"1lyipbflqd1y5cdj4hdml5h1inbr0wwfgp6xw5p5623qv3im16lh"
"1k0rjxqjm703gd9navwzx5x3874b4dxamr62m1fxhm79d271zxis"
else
"09kapzpfwnlr6ghmgwac232bqf2a57mm1brz4cvfx8mlg8bbaw63";
"0ivah1w41jcj13jm740qzwx5h0ia8vbj71pjgd0zrfk3c92kll41";
};
buildAttrs = {
@ -229,12 +229,17 @@ let
# 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
# in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
# 4) Patch tensorflow sources to work with later versions of protobuf. See
# https://github.com/google/jax/issues/9534. Note that this should be
# removed on the next release after 0.3.0.
preBuild = ''
for src in ./jaxlib/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
--replace "status.message()" "std::string{status.message()}"
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'';