diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index 0400a4773ebc..749a2c080962 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -1,11 +1,10 @@ -{ stdenv, fetchurl, fetchgit, fetchpatch, buildPythonPackage, python, pythonOlder, +{ stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python, cudaSupport ? false, cudatoolkit ? null, cudnn ? null, nccl ? null, magma ? null, mklDnnSupport ? true, useSystemNccl ? true, openMPISupport ? false, openmpi ? null, - buildBinaries ? false, buildDocs ? false, cudaArchList ? null, - fetchFromGitHub, lib, numpy, pyyaml, cffi, click, typing, cmake, hypothesis, numactl, psutil, + numpy, pyyaml, cffi, click, typing, cmake, dnnl, hypothesis, numactl, psutil, linkFarm, symlinkJoin, # virtual pkg that consistently instantiates blas across nixpkgs @@ -152,7 +151,15 @@ in buildPythonPackage rec { BUILD_NAMEDTENSOR = true; BUILD_DOCS = buildDocs; + + USE_MKL = blas.implementation == "mkl"; + + # Unlike MKL, MKLDNN is FOSS, so we enable support for it by default. Note + # that this was renamed to dnnl and then renamed again to oneDNN upstream, but + # pytorch still calls it by the old name mkldnn. USE_MKLDNN = mklDnnSupport; + USE_MKLDNN_CBLAS = mklDnnSupport; + preBuild = '' export MAX_JOBS=$NIX_BUILD_CORES ${python.interpreter} setup.py build --cmake-only @@ -174,7 +181,6 @@ in buildPythonPackage rec { done ''; - # Override the (weirdly) wrong version set by default. See # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038 # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267 @@ -199,7 +205,7 @@ in buildPythonPackage rec { ninja ] ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; - buildInputs = [ blas ] + buildInputs = [ blas blas.provider dnnl ] ++ lib.optionals cudaSupport [ cudnn magma nccl ] ++ lib.optionals stdenv.isLinux [ numactl ]; @@ -214,10 +220,13 @@ in buildPythonPackage rec { checkInputs = [ hypothesis ninja psutil ]; - doCheck = false; # tests take a long time for channel release, so doCheck should be overridden only when developing + # Tests take a long time and may be flaky, so just sanity-check imports + doCheck = false; + pythonImportsCheck = [ + "torch" + ]; + checkPhase = with lib.versions; with lib.strings; concatStringsSep " " [ - # MKL 2019.5-only workaround. See: https://github.com/NixOS/nixpkgs/issues/75611 - (optionalString (blas.implementation == "mkl" && majorMinor blas.version == "2019.5") "KMP_INIT_AT_FORK=FALSE ") cudaStubEnv "${python.interpreter} test/run_test.py" "--exclude"