From f9588d69667011ce5e23ab5ca3c5f2c3ab636113 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Thu, 17 Mar 2022 21:08:55 +0000 Subject: [PATCH] python3Packages.{jaxlibWithCuda, jaxlib-bin}: add ptxas to $out/bin --- pkgs/development/python-modules/jaxlib/bin.nix | 12 +++++++++--- pkgs/development/python-modules/jaxlib/default.nix | 6 ++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 7e6b00429df..0929831e32a 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -120,9 +120,15 @@ buildPythonPackage rec { done ''; - # pip dependencies and optionally cudatoolkit. Note that cudatoolkit is - # necessary since jaxlib looks for "ptxas" in $PATH. - propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11; + propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + + # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. + # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for + # more info. + postInstall = lib.optional cudaSupport '' + mkdir -p $out/bin + ln -s ${cudatoolkit_11}/bin/ptxas $out/bin/ptxas + ''; pythonImportsCheck = [ "jaxlib" ]; diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 664e109719a..363bfe56134 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -259,7 +259,13 @@ buildPythonPackage { src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl"; + # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. + # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for + # more info. postInstall = lib.optionalString cudaSupport '' + mkdir -p $out/bin + ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas + find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do addOpenGLRunpath "$lib" patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"