Merge pull request #164176 from samuela/samuela/jaxlib

python3Packages.jaxlib: cudatoolkit is necessary in propagatedBuildInputs when cudaSupport is enabled
This commit is contained in:
Samuel Ainsworth 2022-03-18 08:36:15 -07:00 committed by GitHub
commit 4585fc03be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 3 deletions

View file

@ -120,9 +120,15 @@ buildPythonPackage rec {
done done
''; '';
# pip dependencies and optionally cudatoolkit. Note that cudatoolkit is propagatedBuildInputs = [ absl-py flatbuffers scipy ];
# necessary since jaxlib looks for "ptxas" in $PATH.
propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
# more info.
postInstall = lib.optional cudaSupport ''
mkdir -p $out/bin
ln -s ${cudatoolkit_11}/bin/ptxas $out/bin/ptxas
'';
pythonImportsCheck = [ "jaxlib" ]; pythonImportsCheck = [ "jaxlib" ];

View file

@ -259,7 +259,13 @@ buildPythonPackage {
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl"; src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
# more info.
postInstall = lib.optionalString cudaSupport '' postInstall = lib.optionalString cudaSupport ''
mkdir -p $out/bin
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
addOpenGLRunpath "$lib" addOpenGLRunpath "$lib"
patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib" patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"