python3Packages.jaxlib: pin cudaPackages to 11.6

This commit is contained in:
Samuel Ainsworth 2022-04-19 05:39:59 +00:00
parent 1344d5fe60
commit 92a001fa1c

View file

@ -4243,20 +4243,23 @@ in {
jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
cudaSupport = pkgs.config.cudaSupport or false;
inherit (self.tensorflow) cudaPackages;
# At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
};
jaxlib-build = callPackage ../development/python-modules/jaxlib {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
cudaSupport = pkgs.config.cudaSupport or false;
inherit (self.tensorflow) cudaPackages;
# At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
};
jaxlib = self.jaxlib-build;
jaxlibWithCuda = self.jaxlib-build.override {
cudaSupport = true;
};
jaxlibWithoutCuda = self.jaxlib-build.override {