# Portions of this file are sourced from
# https://github.com/TyberiusPrime/uv2nix_hammer_overrides (MIT License)
{
  # Flake outputs: a per-system devshell whose Python virtualenv is built from
  # the uv workspace in this directory via uv2nix, plus treefmt formatting and
  # a pre-commit hook that runs the treefmt wrapper.
  outputs =
    inputs@{ flake-parts, rebmit, ... }:
    flake-parts.lib.mkFlake { inherit inputs; } {
      inherit (rebmit.lib) systems;

      imports = [
        inputs.devshell.flakeModule
        inputs.git-hooks-nix.flakeModule
        inputs.treefmt-nix.flakeModule
        inputs.rebmit.flakeModule
      ];

      perSystem =
        {
          config,
          pkgs,
          lib,
          ...
        }:
        let
          # Single source of truth for the interpreter: it is used both to build
          # the uv2nix package set and as UV_PYTHON in the devshell, so the
          # virtualenv and uv always agree on the Python version.
          python = pkgs.python311;

          # CUDA release that the wheel overrides below add as buildInputs.
          cudaPackages = pkgs.cudaPackages_12_4;

          # cuda_cudart is only added on x86_64-linux; empty list elsewhere.
          cudartOnX86_64Linux = lib.optionals (pkgs.stdenv.hostPlatform.system == "x86_64-linux") [
            cudaPackages.cuda_cudart
          ];

          workspace = inputs.uv2nix.lib.workspace.loadWorkspace { workspaceRoot = ./.; };

          # Package overlay generated from the workspace, preferring wheels over sdists.
          overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };

          # Per-package fixups for prebuilt wheels. The extra buildInputs are
          # presumably there so autoPatchelf can resolve the CUDA shared
          # libraries these wheels link against — confirm when bumping versions.
          pyprojectOverrides = final: prev: {
            nvidia-cusolver-cu12 = prev.nvidia-cusolver-cu12.overrideAttrs (old: {
              buildInputs = (old.buildInputs or [ ]) ++ [
                cudaPackages.libcublas
                cudaPackages.libcusparse
                cudaPackages.libnvjitlink
              ];
            });

            nvidia-cusparse-cu12 = prev.nvidia-cusparse-cu12.overrideAttrs (old: {
              buildInputs = (old.buildInputs or [ ]) ++ [
                cudaPackages.libnvjitlink
              ];
            });

            torch = prev.torch.overrideAttrs (old: {
              buildInputs =
                (old.buildInputs or [ ])
                ++ cudartOnX86_64Linux
                ++ [
                  cudaPackages.cuda_cupti
                  cudaPackages.cuda_nvrtc
                  cudaPackages.cudnn
                  cudaPackages.libcublas
                  cudaPackages.libcufft
                  cudaPackages.libcurand
                  cudaPackages.libcusolver
                  cudaPackages.libcusparse
                  cudaPackages.nccl
                ];
            });

            torchvision = prev.torchvision.overrideAttrs (old: {
              buildInputs = (old.buildInputs or [ ]) ++ cudartOnX86_64Linux;

              # Add torch's bundled lib/ directory to autoPatchelf's search path
              # (skipped on Darwin). preFixup is a shell snippet, so use
              # optionalString (not the list helper optionals) and append to any
              # existing hook instead of replacing it.
              preFixup =
                (old.preFixup or "")
                + lib.optionalString (!pkgs.stdenv.hostPlatform.isDarwin) ''
                  addAutoPatchelfSearchPath "${final.torch}/${final.python.sitePackages}/torch/lib"
                '';
            });
          };

          # pyproject-nix base package set for `python`, layered with build-system
          # packages, the workspace overlay, and the fixups above (in that order).
          pythonSet =
            (pkgs.callPackage inputs.pyproject-nix.build.packages { inherit python; }).overrideScope
              (
                lib.composeManyExtensions [
                  inputs.pyproject-build-systems.overlays.default
                  overlay
                  pyprojectOverrides
                ]
              );
        in
        {
          devshells.default = {
            packages = [
              # Virtualenv containing every dependency group of the workspace.
              (pythonSet.mkVirtualEnv "python-uv-env" workspace.deps.all)
              pkgs.python311Packages.uv
              pkgs.just
              config.treefmt.build.wrapper
            ];

            env = [
              # Point uv at the same interpreter the virtualenv was built with.
              (lib.nameValuePair "UV_PYTHON" (lib.getExe python))
              (lib.nameValuePair "DEVSHELL_NO_MOTD" 1)
              # Presumably unset so an inherited PYTHONPATH cannot shadow the
              # virtualenv's packages.
              {
                name = "PYTHONPATH";
                unset = true;
              }
            ];

            # Install the git pre-commit hook when the shell starts.
            devshell.startup.pre-commit-hook.text = config.pre-commit.installationScript;
          };

          # Presumably required for the unfree CUDA packages referenced above;
          # this option appears to come from one of the imported flake modules.
          nixpkgs.config = {
            allowUnfree = true;
          };

          treefmt = {
            projectRootFile = "flake.nix";
            programs = {
              nixfmt.enable = true;
              ruff-check.enable = true;
              ruff-format.enable = true;
            };
          };

          # Run the whole treefmt wrapper once per commit rather than per file.
          pre-commit.settings.hooks.treefmt = {
            enable = true;
            name = "treefmt";
            entry = lib.getExe config.treefmt.build.wrapper;
            pass_filenames = false;
          };
        };
    };

  inputs = {
    # flake-parts

    flake-parts.follows = "rebmit/flake-parts";

    # nixpkgs

    nixpkgs.follows = "rebmit/nixpkgs";
    nixpkgs-unstable.follows = "rebmit/nixpkgs-unstable";

    # flake modules

    devshell.follows = "rebmit/devshell";
    git-hooks-nix.follows = "rebmit/git-hooks-nix";
    treefmt-nix.follows = "rebmit/treefmt-nix";

    # libraries

    rebmit.url = "github:rebmit/nix-exprs";
    pyproject-nix = {
      url = "github:pyproject-nix/pyproject.nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
    uv2nix = {
      url = "github:pyproject-nix/uv2nix";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
    pyproject-build-systems = {
      url = "github:pyproject-nix/build-system-pkgs";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.uv2nix.follows = "uv2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };
}