python3Packages.pytorch: add compute capabilities for CUDA 11

CUDA 11 supports capabilities 8.0 and 8.6. The change adds these
capabilities when CUDA 11 is used, enabling support for Ampere GPUs.
This commit is contained in:
Daniël de Kok 2021-02-14 08:18:02 +01:00
parent 406c33bb3d
commit 6038b56de8

View file

@@ -74,27 +74,35 @@ let
# (allowing FBGEMM to be built in pytorch-1.1), and may future proof this
# derivation.
# Capabilities regarded as broken; "3.0" is absent from every default list below.
brokenArchs = [ "3.0" ]; # this variable is only used as documentation.
# Default list of compute-capability strings selected by final_cudaArchList
# when the cudatoolkit major version is "9".
# NOTE(review): in this diff view this looks like the pre-commit definition,
# superseded by cudaCapabilities.cuda9 — confirm against the full file.
# NOTE(review): entries presumably follow PyTorch's TORCH_CUDA_ARCH_LIST
# syntax ("+PTX" suffix) — the consumer is not visible here.
cuda9ArchList = [
"3.5"
"5.0"
"5.2"
"6.0"
"6.1"
"7.0"
"7.0+PTX" # I am getting an "undefined architecture compute_75" on cuda 9
# which leads me to believe this is the final cuda-9-compatible architecture.
];
# Default list used by final_cudaArchList when the cudatoolkit major version
# is not "9": every entry of cuda9ArchList followed by the 7.5 capabilities.
# NOTE(review): in this diff view this looks like the pre-commit definition,
# superseded by cudaCapabilities.cuda10 — confirm against the full file.
cuda10ArchList = cuda9ArchList ++ [
"7.5"
"7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
];
# Default compute-capability lists, keyed by "cuda<major version>".
# final_cudaArchList selects one with
#   cudaCapabilities."cuda${lib.versions.major cudatoolkit.version}"
# so a cudatoolkit major version without an attribute here makes that lookup
# fail at evaluation time.
# `rec` lets each list refer to the previous one: every newer toolkit's list
# is the older list plus the capabilities that toolkit introduced.
# NOTE(review): entries presumably follow PyTorch's TORCH_CUDA_ARCH_LIST
# syntax ("+PTX" suffix) — the consumer is not visible here.
cudaCapabilities = rec {
# Capabilities used for CUDA 9.
cuda9 = [
"3.5"
"5.0"
"5.2"
"6.0"
"6.1"
"7.0"
"7.0+PTX" # I am getting an "undefined architecture compute_75" on cuda 9
# which leads me to believe this is the final cuda-9-compatible architecture.
];
# CUDA 10: everything in cuda9 plus 7.5.
cuda10 = cuda9 ++ [
"7.5"
"7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
];
# CUDA 11: everything in cuda10 plus 8.0 and 8.6.
cuda11 = cuda10 ++ [
"8.0"
"8.0+PTX" # < CUDA toolkit 11.0
"8.6"
"8.6+PTX" # < CUDA toolkit 11.1
];
};
# Architecture list actually used: when cudaSupport is off or the caller passed
# a non-null cudaArchList, cudaArchList is used verbatim; otherwise a default is
# chosen from the major version of cudatoolkit.
# NOTE(review): this span comes from a diff whose +/- markers were lost, so both
# the old and the new `else` branch appear below; as rendered it is not a single
# valid Nix expression. Only one of the two branches exists in the real file.
final_cudaArchList =
if !cudaSupport || cudaArchList != null
then cudaArchList
else
# Presumably the removed (pre-commit) branch: major "9" selects cuda9ArchList,
# any other major selects cuda10ArchList.
if lib.versions.major cudatoolkit.version == "9"
then cuda9ArchList
else cuda10ArchList; # the assert above removes any ambiguity here.
# Presumably the added (post-commit) branch: look up the "cuda<major>" attribute
# of cudaCapabilities (cuda9 / cuda10 / cuda11).
else cudaCapabilities."cuda${lib.versions.major cudatoolkit.version}";
# Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
# LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub