From 6038b56de8a800efcc2ad1ce909c62faeca8b050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Sun, 14 Feb 2021 08:18:02 +0100 Subject: [PATCH] python3Packages.pytorch: add compute capabilities for CUDA 11 CUDA 11 supports capabilities 8.0 and 8.6. The change adds these capabilities when CUDA 11 is used, enabling support for Ampere GPUs. --- .../python-modules/pytorch/default.nix | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index db1914f4ee7..e82054c2885 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -74,27 +74,35 @@ let # (allowing FBGEMM to be built in pytorch-1.1), and may future proof this # derivation. brokenArchs = [ "3.0" ]; # this variable is only used as documentation. - cuda9ArchList = [ - "3.5" - "5.0" - "5.2" - "6.0" - "6.1" - "7.0" - "7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9 - # which leads me to believe this is the final cuda-9-compatible architecture. - ]; - cuda10ArchList = cuda9ArchList ++ [ - "7.5" - "7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0 - ]; + + cudaCapabilities = rec { + cuda9 = [ + "3.5" + "5.0" + "5.2" + "6.0" + "6.1" + "7.0" + "7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9 + # which leads me to believe this is the final cuda-9-compatible architecture. + ]; + + cuda10 = cuda9 ++ [ + "7.5" + "7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0 + ]; + + cuda11 = cuda10 ++ [ + "8.0" + "8.0+PTX" # < CUDA toolkit 11.0 + "8.6" + "8.6+PTX" # < CUDA toolkit 11.1 + ]; + }; final_cudaArchList = if !cudaSupport || cudaArchList != null then cudaArchList - else - if lib.versions.major cudatoolkit.version == "9" - then cuda9ArchList - else cuda10ArchList; # the assert above removes any ambiguity here. + else cudaCapabilities."cuda${lib.versions.major cudatoolkit.version}"; # Normally libcuda.so.1 is provided at runtime by nvidia-x11 via # LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub