python3Packages.torchvision: added cudaSupport option (#132917)

Co-authored-by: Sandro <sandro.jaeckel@gmail.com>
This commit is contained in:
Alexander Kiselyov 2021-08-08 20:42:58 +03:00 committed by GitHub
parent 0d078fcdb2
commit 717538e908
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 3 deletions

View file

@ -301,6 +301,11 @@ in buildPythonPackage rec {
# Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
requiredSystemFeatures = [ "big-parallel" ];
passthru = {
inherit cudaSupport;
cudaArchList = final_cudaArchList;
};
meta = with lib; {
description = "Open source, prototype-to-production deep learning platform";
homepage = "https://pytorch.org/";

View file

@ -1,4 +1,5 @@
{ lib
, symlinkJoin
, buildPythonPackage
, fetchFromGitHub
, ninja
@ -10,9 +11,18 @@
, pillow
, pytorch
, pytest
, cudatoolkit
, cudnn
, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch
}:
buildPythonPackage rec {
let
cudatoolkit_joined = symlinkJoin {
name = "${cudatoolkit.name}-unsplit";
paths = [ cudatoolkit.out cudatoolkit.lib ];
};
cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList;
in buildPythonPackage rec {
pname = "torchvision";
version = "0.10.0";
@ -23,15 +33,22 @@ buildPythonPackage rec {
sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc";
};
nativeBuildInputs = [ libpng ninja which ];
nativeBuildInputs = [ libpng ninja which ]
++ lib.optionals cudaSupport [ cudatoolkit_joined ];
TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
buildInputs = [ libjpeg_turbo libpng ];
buildInputs = [ libjpeg_turbo libpng ]
++ lib.optionals cudaSupport [ cudnn ];
propagatedBuildInputs = [ numpy pillow pytorch scipy ];
preBuild = lib.optionalString cudaSupport ''
export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
export FORCE_CUDA=1
'';
# tries to download many datasets for tests
doCheck = false;
@ -45,6 +62,7 @@ buildPythonPackage rec {
description = "PyTorch vision library";
homepage = "https://pytorch.org/";
license = licenses.bsd3;
platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
maintainers = with maintainers; [ ericsagnes ];
};
}