# nix-exprs/templates/python-uv-torch-cu124/flake.nix (155 lines, 5 KiB, Nix)

# Portions of this file are sourced from
# https://github.com/TyberiusPrime/uv2nix_hammer_overrides (MIT License)
{
outputs =
inputs@{ flake-parts, rebmit, ... }:
flake-parts.lib.mkFlake { inherit inputs; } {
# The list of systems to build for is taken from the `rebmit` flake's `lib`.
inherit (rebmit.lib) systems;
# flake-parts modules provided by the flake inputs. The devshell, git-hooks
# and treefmt modules presumably supply the `devshells`, `pre-commit` and
# `treefmt` options set in `perSystem` — verify against each module.
imports = [
inputs.devshell.flakeModule
inputs.git-hooks-nix.flakeModule
inputs.treefmt-nix.flakeModule
inputs.rebmit.flakeModule
];
perSystem =
{
config,
pkgs,
lib,
...
}:
let
# uv workspace loaded from the directory containing this flake.
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
  workspaceRoot = ./.;
};
# Package overlay generated from the workspace, with "wheel" as the source
# preference.
overlay = workspace.mkPyprojectOverlay {
  sourcePreference = "wheel";
};
# Overlay of per-package fixes applied on top of the generated workspace
# overlay. Each override appends CUDA 12.4 libraries from nixpkgs to the
# package's `buildInputs`, keeping whatever inputs were already declared.
# NOTE(review): presumably this lets the prebuilt wheels' shared objects be
# patched to find those libraries (the torchvision override relies on the
# auto-patchelf hook) — confirm against the build logs.
pyprojectOverrides =
  let
    # Single place to pin the CUDA package set used by every override.
    cuda = pkgs.cudaPackages_12_4;

    # cuda_cudart is only added when building for x86_64-linux.
    cudartOnX86Linux = pkgs.lib.optionals (pkgs.stdenv.hostPlatform.system == "x86_64-linux") [
      cuda.cuda_cudart
    ];
  in
  final: prev: {
    nvidia-cusolver-cu12 = prev.nvidia-cusolver-cu12.overrideAttrs (old: {
      buildInputs = (old.buildInputs or [ ]) ++ [
        cuda.libcublas
        cuda.libcusparse
        cuda.libnvjitlink
      ];
    });

    nvidia-cusparse-cu12 = prev.nvidia-cusparse-cu12.overrideAttrs (old: {
      buildInputs = (old.buildInputs or [ ]) ++ [
        cuda.libnvjitlink
      ];
    });

    torch = prev.torch.overrideAttrs (old: {
      buildInputs =
        (old.buildInputs or [ ])
        ++ cudartOnX86Linux
        ++ [
          cuda.cuda_cupti
          cuda.cuda_nvrtc
          cuda.cudnn
          cuda.libcublas
          cuda.libcufft
          cuda.libcurand
          cuda.libcusolver
          cuda.libcusparse
          cuda.nccl
        ];
    });

    torchvision = prev.torchvision.overrideAttrs (old: {
      buildInputs = (old.buildInputs or [ ]) ++ cudartOnX86Linux;
      # Add torch's bundled `lib` directory to the auto-patchelf search path
      # (skipped on Darwin). `optionalString` is used because the body is a
      # shell snippet, not a list, and any inherited `preFixup` is preserved
      # rather than overwritten.
      preFixup =
        (old.preFixup or "")
        + pkgs.lib.optionalString (!pkgs.stdenv.isDarwin) ''
          addAutoPatchelfSearchPath "${final.torch}/${final.python.sitePackages}/torch/lib"
        '';
    });
  };
# Python 3.11 package set: the pyproject.nix base packages extended, in order,
# with the build-system overlay, the workspace overlay and the local overrides.
pythonSet =
  let
    basePackages = pkgs.callPackage inputs.pyproject-nix.build.packages {
      python = pkgs.python311;
    };
    # Later extensions see the results of earlier ones, so the order matters.
    extensions = lib.composeManyExtensions [
      inputs.pyproject-build-systems.overlays.default
      overlay
      pyprojectOverrides
    ];
  in
  basePackages.overrideScope extensions;
in
{
# Default development shell: the workspace virtualenv plus uv, just and the
# treefmt wrapper, with the pre-commit installation script run at startup.
devshells.default =
  let
    # Virtual environment containing every dependency group of the workspace.
    virtualenv = pythonSet.mkVirtualEnv "python-uv-env" workspace.deps.all;
  in
  {
    packages = [
      virtualenv
      pkgs.python311Packages.uv
      pkgs.just
      config.treefmt.build.wrapper
    ];
    env = [
      {
        # Point uv at the Nix-provided Python 3.11 executable.
        name = "UV_PYTHON";
        value = lib.getExe pkgs.python311;
      }
      {
        # Presumably suppresses devshell's message of the day — confirm.
        name = "DEVSHELL_NO_MOTD";
        value = 1;
      }
      {
        name = "PYTHONPATH";
        unset = true;
      }
    ];
    devshell.startup.pre-commit-hook = {
      text = config.pre-commit.installationScript;
    };
  };
# Allow unfree packages in the nixpkgs instance (presumably required for the
# CUDA packages referenced by the overrides — confirm).
nixpkgs.config.allowUnfree = true;
# treefmt configuration: nixfmt for Nix, ruff (check and format) for Python.
treefmt = {
  projectRootFile = "flake.nix";
  programs.nixfmt.enable = true;
  programs.ruff-check.enable = true;
  programs.ruff-format.enable = true;
};
# Pre-commit hook that runs the configured treefmt wrapper; file names are not
# passed to it.
pre-commit.settings.hooks.treefmt =
  let
    treefmtWrapper = config.treefmt.build.wrapper;
  in
  {
    name = "treefmt";
    enable = true;
    pass_filenames = false;
    entry = lib.getExe treefmtWrapper;
  };
};
};
inputs = {
  # Upstream flake whose pinned inputs are reused below via `follows`.
  rebmit.url = "github:rebmit/nix-exprs";

  # Inputs that follow the versions pinned by `rebmit`.
  flake-parts.follows = "rebmit/flake-parts";
  nixpkgs.follows = "rebmit/nixpkgs";
  nixpkgs-unstable.follows = "rebmit/nixpkgs-unstable";
  devshell.follows = "rebmit/devshell";
  git-hooks-nix.follows = "rebmit/git-hooks-nix";
  treefmt-nix.follows = "rebmit/treefmt-nix";

  # Python packaging libraries, all sharing this flake's nixpkgs and each
  # other's pins.
  pyproject-nix.url = "github:pyproject-nix/pyproject.nix";
  pyproject-nix.inputs.nixpkgs.follows = "nixpkgs";

  uv2nix.url = "github:pyproject-nix/uv2nix";
  uv2nix.inputs.pyproject-nix.follows = "pyproject-nix";
  uv2nix.inputs.nixpkgs.follows = "nixpkgs";

  pyproject-build-systems.url = "github:pyproject-nix/build-system-pkgs";
  pyproject-build-systems.inputs.pyproject-nix.follows = "pyproject-nix";
  pyproject-build-systems.inputs.uv2nix.follows = "uv2nix";
  pyproject-build-systems.inputs.nixpkgs.follows = "nixpkgs";
};
}