diff --git a/pkgs/development/python-modules/torchdiffeq/default.nix b/pkgs/development/python-modules/torchdiffeq/default.nix new file mode 100644 index 00000000000..8195d750c6f --- /dev/null +++ b/pkgs/development/python-modules/torchdiffeq/default.nix @@ -0,0 +1,43 @@ +{ lib +, buildPythonPackage +, fetchPypi + +# dependencies +, torch +, scipy + +# tests +, pytestCheckHook +}: + +buildPythonPackage rec { + pname = "torchdiffeq"; + version = "0.2.3"; + format = "setuptools"; + + src = fetchPypi { + inherit pname version; + hash = "sha256-/nX0NLkJCsDCdwLgK+0hRysPhwNb5lgfUe3F1AE+oxo="; + }; + + propagatedBuildInputs = [ + torch + scipy + ]; + + pythonImportsCheck = [ "torchdiffeq" ]; + + # no tests in sdist, no tags on git + doCheck = false; + + nativeCheckInputs = [ + pytestCheckHook + ]; + + meta = with lib; { + description = "Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation"; + homepage = "https://github.com/rtqichen/torchdiffeq"; + license = licenses.mit; + maintainers = teams.tts.members; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index abd2c52e5ed..253818700ca 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -12149,6 +12149,8 @@ self: super: with self; { torchaudio-bin = callPackage ../development/python-modules/torchaudio/bin.nix { }; + torchdiffeq = callPackage ../development/python-modules/torchdiffeq { }; + torchgpipe = callPackage ../development/python-modules/torchgpipe { }; torchmetrics = callPackage ../development/python-modules/torchmetrics { };