Merge pull request #225661 from SomeoneSerge/jax-libstdcxx
python3Packages.jax: fix libstdc++ mismatch when built with CUDA
This commit is contained in:
commit
929a328dd9
|
@ -18,8 +18,14 @@ final: prev: let
|
|||
# E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++
|
||||
# Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context
|
||||
backendStdenv = final.callPackage ./stdenv.nix {
|
||||
nixpkgsStdenv = prev.pkgs.stdenv;
|
||||
nvccCompatibleStdenv = prev.pkgs.buildPackages."${finalVersion.gcc}Stdenv";
|
||||
# We use buildPackages (= pkgsBuildHost) because we look for a gcc that
|
||||
# runs on our build platform, and that produces executables for the host
|
||||
# platform (= platform on which we deploy and run the downstream packages).
|
||||
# The target platform of buildPackages.gcc is our host platform, so its
|
||||
# .lib output should be the libstdc++ we want to be writing in the runpaths
|
||||
# Cf. https://github.com/NixOS/nixpkgs/pull/225661#discussion_r1164564576
|
||||
nixpkgsCompatibleLibstdcxx = final.pkgs.buildPackages.gcc.cc.lib;
|
||||
nvccCompatibleCC = final.pkgs.buildPackages."${finalVersion.gcc}".cc;
|
||||
};
|
||||
|
||||
### Add classic cudatoolkit package
|
||||
|
|
|
@ -1,17 +1,33 @@
|
|||
{ nixpkgsStdenv
|
||||
, nvccCompatibleStdenv
|
||||
{ lib
|
||||
, nixpkgsCompatibleLibstdcxx
|
||||
, nvccCompatibleCC
|
||||
, overrideCC
|
||||
, stdenv
|
||||
, wrapCCWith
|
||||
}:
|
||||
|
||||
overrideCC nixpkgsStdenv (wrapCCWith {
|
||||
cc = nvccCompatibleStdenv.cc.cc;
|
||||
let
|
||||
cc = wrapCCWith
|
||||
{
|
||||
cc = nvccCompatibleCC;
|
||||
|
||||
# This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++.
|
||||
# Note that libstdc++ maintains forward-compatibility: if we load a newer
|
||||
# libstdc++ into the process, we can still use libraries built against an
|
||||
# older libstdc++. This, in practice, means that we should use libstdc++ from
|
||||
# the same stdenv that the rest of nixpkgs uses.
|
||||
# We currently do not try to support anything other than gcc and linux.
|
||||
libcxx = nixpkgsCompatibleLibstdcxx;
|
||||
};
|
||||
cudaStdenv = overrideCC stdenv cc;
|
||||
passthruExtra = {
|
||||
inherit nixpkgsCompatibleLibstdcxx;
|
||||
# cc already exposed
|
||||
};
|
||||
assertCondition = true;
|
||||
in
|
||||
lib.extendDerivation
|
||||
assertCondition
|
||||
passthruExtra
|
||||
cudaStdenv
|
||||
|
||||
# This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++.
|
||||
# Note that libstdc++ maintains forward-compatibility: if we load a newer
|
||||
# libstdc++ into the process, we can still use libraries built against an
|
||||
# older libstdc++. This, in practice, means that we should use libstdc++ from
|
||||
# the same stdenv that the rest of nixpkgs uses.
|
||||
# We currently do not try to support anything other than gcc and linux.
|
||||
libcxx = nixpkgsStdenv.cc.cc.lib;
|
||||
})
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
}:
|
||||
|
||||
let
|
||||
inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
|
||||
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.3.22";
|
||||
|
@ -81,7 +81,7 @@ let
|
|||
cudatoolkit_cc_joined = symlinkJoin {
|
||||
name = "${cudatoolkit.cc.name}-merged";
|
||||
paths = [
|
||||
cudatoolkit.cc
|
||||
backendStdenv.cc
|
||||
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
|
||||
];
|
||||
};
|
||||
|
@ -271,6 +271,7 @@ let
|
|||
sed -i 's@include/pybind11@pybind11@g' $src
|
||||
done
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
|
||||
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
'' + lib.optionalString stdenv.isDarwin ''
|
||||
# Framework search paths aren't added by bintools hook
|
||||
|
|
Loading…
Reference in a new issue