diff --git a/pkgs/development/python-modules/accelerate/default.nix b/pkgs/development/python-modules/accelerate/default.nix index 4de49a21a83..9aefa229487 100644 --- a/pkgs/development/python-modules/accelerate/default.nix +++ b/pkgs/development/python-modules/accelerate/default.nix @@ -2,6 +2,8 @@ , lib , buildPythonPackage , fetchFromGitHub +, fetchpatch +, pythonAtLeast , pythonOlder , pytestCheckHook , setuptools @@ -17,7 +19,7 @@ buildPythonPackage rec { pname = "accelerate"; - version = "0.19.0"; + version = "0.21.0"; format = "pyproject"; disabled = pythonOlder "3.7"; @@ -25,9 +27,18 @@ buildPythonPackage rec { owner = "huggingface"; repo = pname; rev = "refs/tags/v${version}"; - hash = "sha256-gW4wCpkyxoWfxXu8UHZfgopSQhOoPhGgqEqFiHJ+Db4="; + hash = "sha256-BwM3gyNhsRkxtxLNrycUGwBmXf8eq/7b56/ykMryt5w="; }; + patches = [ + # fix import error when torch>=2.0.1 and torch.distributed is disabled + # https://github.com/huggingface/accelerate/pull/1800 + (fetchpatch { + url = "https://github.com/huggingface/accelerate/commit/32701039d302d3875c50c35ab3e76c467755eae9.patch"; + hash = "sha256-Hth7qyOfx1sC8UaRdbYTnyRXD/VRKf41GtLc0ee1t2I="; + }) + ]; + nativeBuildInputs = [ setuptools ]; propagatedBuildInputs = [ @@ -53,15 +64,25 @@ buildPythonPackage rec { # try to download data: "FeatureExamplesTests" "test_infer_auto_device_map_on_t0pp" - # known failure with Torch>2.0; see https://github.com/huggingface/accelerate/pull/1339: - # (remove for next release) - "test_gradient_sync_cpu_multi" ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly "CheckpointTest" + ] ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [ + # RuntimeError: torch_shm_manager: execl failed: Permission denied + "CheckpointTest" + ] ++ lib.optionals (pythonAtLeast "3.11") [ + # python3.11 not yet supported for torch.compile + "test_dynamo_extract_model" ]; - # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException: - doCheck = !stdenv.isDarwin; + + disabledTestPaths = lib.optionals (!(stdenv.isLinux && stdenv.isx86_64)) [ + # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException: + "tests/test_cpu.py" + "tests/test_grad_sync.py" + "tests/test_metrics.py" + "tests/test_scheduler.py" + ]; + pythonImportsCheck = [ "accelerate" ];