Merge pull request #225661 from SomeoneSerge/jax-libstdcxx

python3Packages.jax: fix libstdc++ mismatch when built with CUDA
This commit is contained in:
Samuel Ainsworth 2023-04-13 12:28:13 -04:00 committed by GitHub
commit 929a328dd9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 16 deletions

View file

@ -18,8 +18,14 @@ final: prev: let
# E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++
# Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context
backendStdenv = final.callPackage ./stdenv.nix {
nixpkgsStdenv = prev.pkgs.stdenv;
nvccCompatibleStdenv = prev.pkgs.buildPackages."${finalVersion.gcc}Stdenv";
# We use buildPackages (= pkgsBuildHost) because we look for a gcc that
# runs on our build platform, and that produces executables for the host
# platform (= platform on which we deploy and run the downstream packages).
# The target platform of buildPackages.gcc is our host platform, so its
# .lib output should be the libstdc++ we want to be writing in the runpaths
# Cf. https://github.com/NixOS/nixpkgs/pull/225661#discussion_r1164564576
nixpkgsCompatibleLibstdcxx = final.pkgs.buildPackages.gcc.cc.lib;
nvccCompatibleCC = final.pkgs.buildPackages."${finalVersion.gcc}".cc;
};
### Add classic cudatoolkit package

View file

@ -1,17 +1,33 @@
{ nixpkgsStdenv
, nvccCompatibleStdenv
{ lib
, nixpkgsCompatibleLibstdcxx
, nvccCompatibleCC
, overrideCC
, stdenv
, wrapCCWith
}:
overrideCC nixpkgsStdenv (wrapCCWith {
cc = nvccCompatibleStdenv.cc.cc;
let
cc = wrapCCWith
{
cc = nvccCompatibleCC;
# This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++.
# Note that libstdc++ maintains forward-compatibility: if we load a newer
# libstdc++ into the process, we can still use libraries built against an
# older libstdc++. This, in practice, means that we should use libstdc++ from
# the same stdenv that the rest of nixpkgs uses.
# We currently do not try to support anything other than gcc and linux.
libcxx = nixpkgsCompatibleLibstdcxx;
};
cudaStdenv = overrideCC stdenv cc;
passthruExtra = {
inherit nixpkgsCompatibleLibstdcxx;
# cc already exposed
};
assertCondition = true;
in
lib.extendDerivation
assertCondition
passthruExtra
cudaStdenv
# This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++.
# Note that libstdc++ maintains forward-compatibility: if we load a newer
# libstdc++ into the process, we can still use libraries built against an
# older libstdc++. This, in practice, means that we should use libstdc++ from
# the same stdenv that the rest of nixpkgs uses.
# We currently do not try to support anything other than gcc and linux.
libcxx = nixpkgsStdenv.cc.cc.lib;
})

View file

@ -49,7 +49,7 @@
}:
let
inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
pname = "jaxlib";
version = "0.3.22";
@ -81,7 +81,7 @@ let
cudatoolkit_cc_joined = symlinkJoin {
name = "${cudatoolkit.cc.name}-merged";
paths = [
cudatoolkit.cc
backendStdenv.cc
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
];
};
@ -271,6 +271,7 @@ let
sed -i 's@include/pybind11@pybind11@g' $src
done
'' + lib.optionalString cudaSupport ''
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook