Merge pull request #246712 from NickCao/jax-rework
python3Packages.{jax,jaxlib}: update to 0.4.14
This commit is contained in:
commit
b98f6d9072
7 changed files with 216 additions and 117 deletions
|
@ -10,9 +10,12 @@ args@{
|
|||
, bazelFlags ? []
|
||||
, bazelBuildFlags ? []
|
||||
, bazelTestFlags ? []
|
||||
, bazelRunFlags ? []
|
||||
, runTargetFlags ? []
|
||||
, bazelFetchFlags ? []
|
||||
, bazelTargets
|
||||
, bazelTargets ? []
|
||||
, bazelTestTargets ? []
|
||||
, bazelRunTarget ? null
|
||||
, buildAttrs
|
||||
, fetchAttrs
|
||||
|
||||
|
@ -46,17 +49,23 @@ args@{
|
|||
|
||||
let
|
||||
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
|
||||
name = name;
|
||||
bazelFlags = bazelFlags;
|
||||
bazelBuildFlags = bazelBuildFlags;
|
||||
bazelTestFlags = bazelTestFlags;
|
||||
bazelFetchFlags = bazelFetchFlags;
|
||||
bazelTestTargets = bazelTestTargets;
|
||||
dontAddBazelOpts = dontAddBazelOpts;
|
||||
inherit
|
||||
name
|
||||
bazelFlags
|
||||
bazelBuildFlags
|
||||
bazelTestFlags
|
||||
bazelRunFlags
|
||||
runTargetFlags
|
||||
bazelFetchFlags
|
||||
bazelTargets
|
||||
bazelTestTargets
|
||||
bazelRunTarget
|
||||
dontAddBazelOpts
|
||||
;
|
||||
};
|
||||
fBuildAttrs = fArgs // buildAttrs;
|
||||
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
|
||||
bazelCmd = { cmd, additionalFlags, targets }:
|
||||
bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
|
||||
lib.optionalString (targets != [ ]) ''
|
||||
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
|
||||
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
|
||||
|
@ -73,7 +82,8 @@ let
|
|||
"''${host_linkopts[@]}" \
|
||||
$bazelFlags \
|
||||
${lib.strings.concatStringsSep " " additionalFlags} \
|
||||
${lib.strings.concatStringsSep " " targets}
|
||||
${lib.strings.concatStringsSep " " targets} \
|
||||
${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
|
||||
'';
|
||||
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
|
||||
# chmod: cannot operate on dangling symlink '$symlink'
|
||||
|
@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // {
|
|||
targets = fBuildAttrs.bazelTargets;
|
||||
}
|
||||
}
|
||||
${
|
||||
bazelCmd {
|
||||
cmd = "run";
|
||||
additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
|
||||
# Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
|
||||
targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
|
||||
targetRunFlags = fBuildAttrs.runTargetFlags;
|
||||
}
|
||||
}
|
||||
runHook postBuild
|
||||
'';
|
||||
})
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
{ lib
|
||||
, absl-py
|
||||
, blas
|
||||
, buildPythonPackage
|
||||
, etils
|
||||
, setuptools
|
||||
, importlib-metadata
|
||||
, fetchFromGitHub
|
||||
, jaxlib
|
||||
, jaxlib-bin
|
||||
, lapack
|
||||
, matplotlib
|
||||
, ml-dtypes
|
||||
, numpy
|
||||
, opt-einsum
|
||||
, pytestCheckHook
|
||||
|
@ -15,7 +16,6 @@
|
|||
, pythonOlder
|
||||
, scipy
|
||||
, stdenv
|
||||
, typing-extensions
|
||||
}:
|
||||
|
||||
let
|
||||
|
@ -27,30 +27,32 @@ let
|
|||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "jax";
|
||||
version = "0.4.5";
|
||||
format = "setuptools";
|
||||
version = "0.4.14";
|
||||
format = "pyproject";
|
||||
|
||||
disabled = pythonOlder "3.7";
|
||||
disabled = pythonOlder "3.9";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "google";
|
||||
repo = pname;
|
||||
# google/jax contains tags for jax and jaxlib. Only use jax tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
|
||||
hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
# jaxlib is _not_ included in propagatedBuildInputs because there are
|
||||
# different versions of jaxlib depending on the desired target hardware. The
|
||||
# JAX project ships separate wheels for CPU, GPU, and TPU.
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
etils
|
||||
ml-dtypes
|
||||
numpy
|
||||
opt-einsum
|
||||
scipy
|
||||
typing-extensions
|
||||
] ++ etils.optional-dependencies.epath;
|
||||
] ++ lib.optional (pythonOlder "3.10") importlib-metadata;
|
||||
|
||||
nativeCheckInputs = [
|
||||
jaxlib'
|
||||
|
@ -96,24 +98,12 @@ buildPythonPackage rec {
|
|||
"testScanGrad_jit_scan"
|
||||
];
|
||||
|
||||
# See https://github.com/google/jax/issues/11722. This is a temporary fix in
|
||||
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
|
||||
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
|
||||
disabledTestPaths = [
|
||||
"tests/api_test.py"
|
||||
"tests/core_test.py"
|
||||
"tests/lax_numpy_indexing_test.py"
|
||||
"tests/lax_numpy_test.py"
|
||||
"tests/nn_test.py"
|
||||
"tests/random_test.py"
|
||||
"tests/sparse_test.py"
|
||||
] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
||||
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
||||
# RuntimeWarning: invalid value encountered in cast
|
||||
"tests/lax_test.py"
|
||||
];
|
||||
|
||||
# As of 0.3.22, `import jax` does not work without jaxlib being installed.
|
||||
pythonImportsCheck = [ ];
|
||||
pythonImportsCheck = [ "jax" ];
|
||||
|
||||
meta = with lib; {
|
||||
description = "Differentiate, compile, and transform Numpy code";
|
||||
|
|
|
@ -18,11 +18,12 @@
|
|||
, autoPatchelfHook
|
||||
, buildPythonPackage
|
||||
, config
|
||||
, cudnn ? cudaPackages.cudnn
|
||||
, fetchPypi
|
||||
, fetchurl
|
||||
, flatbuffers
|
||||
, isPy39
|
||||
, jaxlib-build
|
||||
, lib
|
||||
, ml-dtypes
|
||||
, python
|
||||
, scipy
|
||||
, stdenv
|
||||
|
@ -35,46 +36,57 @@ let
|
|||
inherit (cudaPackages) cudatoolkit cudnn;
|
||||
in
|
||||
|
||||
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
|
||||
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
|
||||
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;
|
||||
|
||||
let
|
||||
version = "0.4.4";
|
||||
version = "0.4.14";
|
||||
|
||||
inherit (python) pythonVersion;
|
||||
|
||||
# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
|
||||
# official instructions recommend installing CPU-only versions via PyPI.
|
||||
cpuSrcs =
|
||||
let
|
||||
getSrcFromPypi = { platform, hash }: fetchPypi {
|
||||
inherit version platform hash;
|
||||
pname = "jaxlib";
|
||||
format = "wheel";
|
||||
# See the `disabled` attr comment below.
|
||||
dist = "cp310";
|
||||
python = "cp310";
|
||||
abi = "cp310";
|
||||
};
|
||||
in
|
||||
{
|
||||
"x86_64-linux" = getSrcFromPypi {
|
||||
platform = "manylinux2014_x86_64";
|
||||
hash = "sha256-nyylSZfqHeftlvVgJZFCN1ldjluZVJIYu4ZSsVxvXf8=";
|
||||
};
|
||||
"aarch64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_11_0_arm64";
|
||||
hash = "sha256-La3wYbGCjWTl7krBD6BaBRqyBD8R530Lckbz0AWv0FM=";
|
||||
};
|
||||
"x86_64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_10_14_x86_64";
|
||||
hash = "sha256-hDg5+qisgtgOrdvbjxsUgI73cW6Aah8NLjhPe4kMAsM=";
|
||||
};
|
||||
};
|
||||
|
||||
pythonVersion = python.pythonVersion;
|
||||
|
||||
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
|
||||
# When upgrading, you can get these hashes from prefetch.sh. See
|
||||
# https://github.com/google/jax/issues/12879 as to why this specific URL is
|
||||
# the correct index.
|
||||
cpuSrcs = {
|
||||
"x86_64-linux" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
|
||||
};
|
||||
"aarch64-darwin" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
|
||||
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
|
||||
};
|
||||
"x86_64-darwin" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
|
||||
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
|
||||
};
|
||||
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
|
||||
gpuSrc = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-CcQ5kjp4XfUX4/RwFY3T5G3kVKAeyoCTXu1Lo4O16Qo=";
|
||||
};
|
||||
|
||||
gpuSrc = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
|
||||
};
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage {
|
||||
pname = "jaxlib";
|
||||
inherit version;
|
||||
format = "wheel";
|
||||
|
||||
# At the time of writing (2022-10-19), there are releases for <=3.10.
|
||||
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
|
||||
# python version.
|
||||
disabled = !(pythonVersion == "3.10");
|
||||
|
||||
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
|
||||
|
@ -87,9 +99,10 @@ buildPythonPackage rec {
|
|||
|
||||
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
||||
# Run `autoPatchelfHook` to automagically fix them.
|
||||
nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
|
||||
nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
|
||||
++ lib.optionals cudaSupport [ addOpenGLRunpath ];
|
||||
# Dynamic link dependencies
|
||||
buildInputs = [ stdenv.cc.cc ];
|
||||
buildInputs = [ stdenv.cc.cc.lib ];
|
||||
|
||||
# jaxlib contains shared libraries that open other shared libraries via dlopen
|
||||
# and these implicit dependencies are not recognized by ldd or
|
||||
|
@ -113,7 +126,12 @@ buildPythonPackage rec {
|
|||
done
|
||||
'';
|
||||
|
||||
propagatedBuildInputs = [ absl-py flatbuffers scipy ];
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
flatbuffers
|
||||
ml-dtypes
|
||||
scipy
|
||||
];
|
||||
|
||||
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
|
||||
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||
|
@ -123,7 +141,7 @@ buildPythonPackage rec {
|
|||
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
|
||||
'';
|
||||
|
||||
pythonImportsCheck = [ "jaxlib" ];
|
||||
inherit (jaxlib-build) pythonImportsCheck;
|
||||
|
||||
meta = with lib; {
|
||||
description = "XLA library for JAX";
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
# Build-time dependencies:
|
||||
, addOpenGLRunpath
|
||||
, bazel_5
|
||||
, bazel_6
|
||||
, binutils
|
||||
, buildBazelPackage
|
||||
, buildPythonPackage
|
||||
|
@ -21,11 +21,13 @@
|
|||
, setuptools
|
||||
, symlinkJoin
|
||||
, wheel
|
||||
, build
|
||||
, which
|
||||
|
||||
# Python dependencies:
|
||||
, absl-py
|
||||
, flatbuffers
|
||||
, ml-dtypes
|
||||
, numpy
|
||||
, scipy
|
||||
, six
|
||||
|
@ -35,7 +37,6 @@
|
|||
, giflib
|
||||
, grpc
|
||||
, libjpeg_turbo
|
||||
, protobuf
|
||||
, python
|
||||
, snappy
|
||||
, zlib
|
||||
|
@ -53,7 +54,7 @@ let
|
|||
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.4.4";
|
||||
version = "0.4.14";
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||
|
@ -99,7 +100,9 @@ let
|
|||
# "com_github_googleapis_googleapis"
|
||||
# "com_github_googlecloudplatform_google_cloud_cpp"
|
||||
"com_github_grpc_grpc"
|
||||
"com_google_protobuf"
|
||||
# ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
|
||||
# target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
|
||||
# "com_google_protobuf"
|
||||
# Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
|
||||
# "com_googlesource_code_re2"
|
||||
"curl"
|
||||
|
@ -120,7 +123,9 @@ let
|
|||
"org_sqlite"
|
||||
"pasta"
|
||||
"png"
|
||||
"pybind11"
|
||||
# ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
|
||||
# target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
|
||||
# "pybind11"
|
||||
"six_archive"
|
||||
"snappy"
|
||||
"tblib_archive"
|
||||
|
@ -138,14 +143,15 @@ let
|
|||
bazel-build = buildBazelPackage rec {
|
||||
name = "bazel-build-${pname}-${version}";
|
||||
|
||||
bazel = bazel_5;
|
||||
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
|
||||
bazel = bazel_6;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "google";
|
||||
repo = "jax";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
|
||||
hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
|
@ -154,6 +160,7 @@ let
|
|||
git
|
||||
setuptools
|
||||
wheel
|
||||
build
|
||||
which
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
cctools
|
||||
|
@ -169,7 +176,7 @@ let
|
|||
numpy
|
||||
openssl
|
||||
pkgs.flatbuffers
|
||||
protobuf
|
||||
pkgs.protobuf
|
||||
pybind11
|
||||
scipy
|
||||
six
|
||||
|
@ -188,7 +195,8 @@ let
|
|||
rm -f .bazelversion
|
||||
'';
|
||||
|
||||
bazelTargets = [ "//build:build_wheel" ];
|
||||
bazelRunTarget = "//jaxlib/tools:build_wheel";
|
||||
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
|
||||
|
||||
removeRulesCC = false;
|
||||
|
||||
|
@ -207,7 +215,11 @@ let
|
|||
build --action_env=PYENV_ROOT
|
||||
build --python_path="${python}/bin/python"
|
||||
build --distinct_host_configuration=false
|
||||
build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
|
||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
||||
'' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
|
||||
build --config=avx_posix
|
||||
'' + lib.optionalString mklSupport ''
|
||||
build --config=mkl_open_source_only
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
|
||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||
|
@ -234,7 +246,7 @@ let
|
|||
fetchAttrs = {
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
|
||||
# we have to force @mkl_dnn_v1 since it's not needed on darwin
|
||||
bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
|
||||
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
|
||||
bazelFlags = bazelFlags ++ [
|
||||
"--config=avx_posix"
|
||||
] ++ lib.optionals cudaSupport [
|
||||
|
@ -247,11 +259,12 @@ let
|
|||
"--config=mkl_open_source_only"
|
||||
];
|
||||
|
||||
sha256 =
|
||||
if cudaSupport then
|
||||
"sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
|
||||
else
|
||||
"sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
|
||||
sha256 = (if cudaSupport then {
|
||||
x86_64-linux = "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI=";
|
||||
} else {
|
||||
x86_64-linux = "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ=";
|
||||
aarch64-linux = "sha256-edkYcdlvOLNGRSanch1fGCZwq8SFn3TzcUNt1LhzG/E=";
|
||||
}).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
|
@ -261,25 +274,13 @@ let
|
|||
"nsync" # fails to build on darwin
|
||||
]);
|
||||
|
||||
bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
|
||||
"--config=avx_posix"
|
||||
] ++ lib.optionals cudaSupport [
|
||||
"--config=cuda"
|
||||
] ++ lib.optionals mklSupport [
|
||||
"--config=mkl_open_source_only"
|
||||
];
|
||||
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
|
||||
# 1) Fix pybind11 include paths.
|
||||
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
||||
# 3) Patch python path in the compiler driver.
|
||||
preBuild = ''
|
||||
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
|
||||
sed -i 's@include/pybind11@pybind11@g' $src
|
||||
done
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
# 2) Patch python path in the compiler driver.
|
||||
preBuild = lib.optionalString cudaSupport ''
|
||||
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
|
||||
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
'' + lib.optionalString stdenv.isDarwin ''
|
||||
# Framework search paths aren't added by bintools hook
|
||||
# https://github.com/NixOS/nixpkgs/pull/41914
|
||||
|
@ -289,16 +290,12 @@ let
|
|||
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
|
||||
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
|
||||
'' + (if stdenv.cc.isGNU then ''
|
||||
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
||||
'' else if stdenv.cc.isClang then ''
|
||||
sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
|
||||
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
|
||||
|
||||
installPhase = ''
|
||||
./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
|
||||
'';
|
||||
};
|
||||
|
||||
inherit meta;
|
||||
|
@ -345,13 +342,19 @@ buildPythonPackage {
|
|||
grpc
|
||||
jsoncpp
|
||||
libjpeg_turbo
|
||||
ml-dtypes
|
||||
numpy
|
||||
scipy
|
||||
six
|
||||
snappy
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "jaxlib" ];
|
||||
pythonImportsCheck = [
|
||||
"jaxlib"
|
||||
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
|
||||
"jaxlib.cpu_feature_guard"
|
||||
"jaxlib.xla_client"
|
||||
];
|
||||
|
||||
# Without it there are complaints about libcudart.so.11.0 not being found
|
||||
# because RPATH path entries added above are stripped.
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
version="$1"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
|
||||
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
|
||||
#!/usr/bin/env bash
|
||||
|
||||
prefetch () {
|
||||
expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url"
|
||||
url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r)
|
||||
echo "$url"
|
||||
sha256=$(nix-prefetch-url "$url")
|
||||
nix hash to-sri --type sha256 "$sha256"
|
||||
echo
|
||||
}
|
||||
|
||||
prefetch "x86_64-linux" "false"
|
||||
prefetch "aarch64-darwin" "false"
|
||||
prefetch "x86_64-darwin" "false"
|
||||
prefetch "x86_64-linux" "true"
|
||||
|
|
60
pkgs/development/python-modules/ml-dtypes/default.nix
Normal file
60
pkgs/development/python-modules/ml-dtypes/default.nix
Normal file
|
@ -0,0 +1,60 @@
|
|||
{ lib
|
||||
, buildPythonPackage
|
||||
, pythonOlder
|
||||
, fetchFromGitHub
|
||||
, setuptools
|
||||
, pybind11
|
||||
, numpy
|
||||
, pytestCheckHook
|
||||
, absl-py
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "ml-dtypes";
|
||||
version = "0.2.0";
|
||||
format = "pyproject";
|
||||
|
||||
disabled = pythonOlder "3.7";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "jax-ml";
|
||||
repo = "ml_dtypes";
|
||||
rev = "refs/tags/v${version}";
|
||||
hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI=";
|
||||
# Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521),
|
||||
# the attempts to use the nixpkgs packaged eigen dependency have failed.
|
||||
# Hence, we rely on the bundled eigen library.
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
setuptools
|
||||
pybind11
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
numpy
|
||||
];
|
||||
|
||||
nativeCheckInputs = [
|
||||
pytestCheckHook
|
||||
absl-py
|
||||
];
|
||||
|
||||
preCheck = ''
|
||||
# remove src module, so tests use the installed module instead
|
||||
mv ./ml_dtypes/tests ./tests
|
||||
rm -rf ./ml_dtypes
|
||||
'';
|
||||
|
||||
pythonImportsCheck = [
|
||||
"ml_dtypes"
|
||||
];
|
||||
|
||||
meta = with lib; {
|
||||
description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries";
|
||||
homepage = "https://github.com/jax-ml/ml_dtypes";
|
||||
license = licenses.asl20;
|
||||
maintainers = with maintainers; [ GaetanLepage samuela ];
|
||||
};
|
||||
}
|
|
@ -5320,7 +5320,6 @@ self: super: with self; {
|
|||
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
|
||||
inherit (pkgs.config) cudaSupport;
|
||||
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
||||
protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
|
||||
};
|
||||
|
||||
jaxlib = self.jaxlib-build;
|
||||
|
@ -6573,6 +6572,8 @@ self: super: with self; {
|
|||
|
||||
ml-collections = callPackage ../development/python-modules/ml-collections { };
|
||||
|
||||
ml-dtypes = callPackage ../development/python-modules/ml-dtypes { };
|
||||
|
||||
mlflow = callPackage ../development/python-modules/mlflow { };
|
||||
|
||||
mlrose = callPackage ../development/python-modules/mlrose { };
|
||||
|
|
Loading…
Reference in a new issue