Files
2026-06-25 21:29:21 +00:00

99 lines
3.1 KiB
Python

#!/usr/bin/env python3
"""Sync Requirements - Automatically upgrade test requirements pinned
versions from pre-commit config file."""
from __future__ import annotations

import re
import sys
from pathlib import Path
from typing import TYPE_CHECKING

from yaml import load as load_yaml

if TYPE_CHECKING:
    from collections.abc import Generator

    from yaml import CLoader as _CLoader, Loader as _Loader

    Loader: type[_CLoader | _Loader]

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader
def yield_pre_commit_version_data(
    pre_commit_text: str,
) -> Generator[tuple[str, str], None, None]:
    """Yield (name, rev) tuples from pre-commit config file.

    *name* is the last path component of each repository URL, with any
    trailing ``/`` and ``.git`` suffix removed, so
    ``https://github.com/org/tool.git`` and ``https://github.com/org/tool/``
    both yield ``tool``. *rev* is the repository's ``rev`` value with one
    leading ``v`` removed. Entries lacking either a ``repo`` or a ``rev`` key
    are skipped.

    Args:
        pre_commit_text: Text of a pre-commit YAML configuration file.

    Raises:
        KeyError: if the parsed document has no top-level ``repos`` key.
    """
    # NOTE(review): ``Loader`` may be the full (C)Loader rather than a safe
    # loader; fine for a repository-local config, but switch to a SafeLoader
    # if this is ever fed untrusted text.
    pre_commit_config = load_yaml(pre_commit_text, Loader)
    for repo in pre_commit_config["repos"]:
        if "repo" not in repo or "rev" not in repo:
            continue
        # Tolerate a trailing slash and a ".git" suffix; otherwise the name
        # would come out as "" or "tool.git" and never match a requirement.
        url = repo["repo"].rstrip("/")
        name = url.rsplit("/", 1)[-1].removesuffix(".git")
        rev = repo["rev"].removeprefix("v")
        yield name, rev
# Runs of "-", "_" and "." are equivalent in package names (PEP 503).
_NAME_SEPARATORS = re.compile(r"[-_.]+")
# Text following "==": optional blanks, the version token, then everything
# else (environment marker, inline comment, line ending) kept verbatim.
_PIN_TAIL = re.compile(r"([ \t]*)([^\s;,#]*)(.*)", re.DOTALL)


def _normalize_name(name: str) -> str:
    """Return *name* in PEP 503 normalized form (lowercase, "-" separators)."""
    return _NAME_SEPARATORS.sub("-", name.strip()).lower()


def update_requirements(
    requirements: Path,
    version_data: dict[str, str],
) -> bool:
    """Return if updated requirements file.

    Update pinned ``name==version`` lines in *requirements* so their versions
    match *version_data*. Names are compared in PEP 503 normalized form with
    any ``[extras]`` removed, so ``Foo_Bar[x]`` matches a ``foo-bar`` key.
    Only the version token is replaced: surrounding whitespace, environment
    markers, inline comments and the line ending are preserved as written.

    The new contents are built in memory and the file is rewritten only when
    at least one version changed, so an error part-way through never leaves
    a truncated file and unchanged files are not touched.

    Args:
        requirements: Requirements file, read and written as UTF-8.
        version_data: Mapping of package name to the desired version.

    Returns:
        True if any pinned version was changed (and the file rewritten),
        otherwise False.
    """
    wanted = {_normalize_name(key): value for key, value in version_data.items()}
    changed = False
    new_lines: list[str] = []
    for line in requirements.read_text(encoding="utf-8").splitlines(True):
        # Comments and lines without an exact "==" pin are copied unchanged.
        if line.lstrip().startswith("#") or "==" not in line:
            new_lines.append(line)
            continue
        name, rest = line.split("==", 1)
        version = wanted.get(_normalize_name(name.split("[", 1)[0]))
        # Packages not managed by pre-commit are left alone.
        if version is None:
            new_lines.append(line)
            continue
        lead, old_version, tail = _PIN_TAIL.fullmatch(rest).groups()
        new_line = f"{name}=={lead}{version}{tail}"
        if new_line != line:
            if not changed:
                changed = True
                print("Changed test requirements version to match pre-commit")
            display_name = name.strip()
            print(f"{display_name}=={old_version} -> {display_name}=={version}")
        new_lines.append(new_line)
    if changed:
        requirements.write_text("".join(new_lines), encoding="utf-8")
    return changed
if __name__ == "__main__":
    # All paths are resolved relative to the current working directory,
    # which is expected to be the repository root.
    source_root = Path.cwd()
    # Double-check we found the right directory. An explicit check (rather
    # than ``assert``) still runs when Python is started with -O.
    if not (source_root / "LICENSE").exists():
        sys.exit(
            f"{source_root} has no LICENSE file; "
            "run this script from the repository root.",
        )
    pre_commit = source_root / ".pre-commit-config.yaml"
    test_requirements = source_root / "test-requirements.txt"
    pre_commit_text = pre_commit.read_text(encoding="utf-8")
    # Collect tool versions from pre-commit, turning repo names into package
    # names by dropping "-mirror" and "-pre-commit" suffixes
    # (e.g. "black-pre-commit-mirror" -> "black").
    pre_commit_versions = {
        name.removesuffix("-mirror").removesuffix("-pre-commit"): version
        for name, version in yield_pre_commit_version_data(pre_commit_text)
    }
    changed = update_requirements(test_requirements, pre_commit_versions)
    # Exit status is 1 when the requirements file was modified, 0 otherwise
    # (presumably so a pre-commit hook / CI run fails on changes).
    sys.exit(int(changed))