Merge pull request #219778 from samuela/samuela/jax

Update JAX and fix aarch64-darwin build
This commit is contained in:
Samuel Ainsworth 2023-04-19 15:53:59 -04:00 committed by GitHub
commit 8faef6de41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 17 deletions

View file

@ -5,6 +5,7 @@
, etils
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, numpy
@ -13,15 +14,20 @@
, pytest-xdist
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:
let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
# jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
# fine. jaxlib is only used in the checkPhase, so switching backends does not
# impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.1";
version = "0.4.5";
format = "setuptools";
disabled = pythonOlder "3.7";
@ -29,14 +35,14 @@ buildPythonPackage rec {
src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "refs/tags/jaxlib-v${version}";
hash = "sha256-ajLI0iD0YZRK3/uKSbhlIZGc98MdW174vA34vhoy7Iw=";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
};
# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
# JAX project ships separate wheels for CPU, GPU, and TPU.
propagatedBuildInputs = [
absl-py
etils
@ -47,7 +53,7 @@ buildPythonPackage rec {
] ++ etils.optional-dependencies.epath;
nativeCheckInputs = [
jaxlib
jaxlib'
matplotlib
pytestCheckHook
pytest-xdist
@ -83,6 +89,11 @@ buildPythonPackage rec {
"test_custom_linear_solve_cholesky"
"test_custom_root_with_aux"
"testEigvalsGrad_shape"
] ++ lib.optionals (stdenv.isAarch64 && stdenv.isDarwin) [
# See https://github.com/google/jax/issues/14793.
"test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
"testQdwhWithRandomMatrix3"
"testScanGrad_jit_scan"
];
# See https://github.com/google/jax/issues/11722. This is a temporary fix in

View file

@ -39,7 +39,7 @@ assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
let
version = "0.3.22";
version = "0.4.4";
pythonVersion = python.pythonVersion;
@ -50,21 +50,21 @@ let
cpuSrcs = {
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
};
"aarch64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
};
"x86_64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
hash = "sha256-bOoQI+T+YsTUNA+cDu6wwYTcq9fyyzCpK9qrdCrNVoA=";
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
};
};
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
};
in
buildPythonPackage rec {
@ -77,7 +77,13 @@ buildPythonPackage rec {
# python version.
disabled = !(pythonVersion == "3.10");
src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
src =
if !cudaSupport then
(
cpuSrcs."${stdenv.hostPlatform.system}"
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
) else gpuSrc;
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.

View file

@ -52,7 +52,7 @@ let
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
pname = "jaxlib";
version = "0.3.22";
version = "0.4.4";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -137,8 +137,9 @@ let
src = fetchFromGitHub {
owner = "google";
repo = "jax";
rev = "${pname}-v${version}";
hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
};
nativeBuildInputs = [
@ -242,9 +243,9 @@ let
sha256 =
if cudaSupport then
"sha256-4yu4y4SwSQoeaOz9yojhvCRGSC6jp61ycVDIKyIK/l8="
"sha256-cgsiloW77p4+TKRrYequZ/UwKwfO2jsHKtZ+aA30H7E="
else
"sha256-CyRfPfJc600M7VzR3/SQX/EAyeaXRJwDQWot5h2XnFU=";
"sha256-D7WYG3YUaWq+4APYx8WpA191VVtoHG0fth3uEHXOeos=";
};
buildAttrs = {