Split tests by execution time

Robert Resch 2025-03-27 15:20:05 +01:00
parent e8aa3e6d34
commit 7b0e4871da
No known key found for this signature in database
GPG Key ID: 9D9D9DCB43120143
4 changed files with 202 additions and 77 deletions

.github/workflows/ci.yaml

@@ -893,6 +893,21 @@ jobs:
with:
python-version: ${{ env.DEFAULT_PYTHON }}
check-latest: true
- name: Generate partial pytest execution time restore key
id: generate-pytest-execution-time-key
run: |
echo "key=pytest-execution-time-${{
env.HA_SHORT_VERSION }}-$(date -u '%Y-%m-%dT%H:%M:%s')" >> $GITHUB_OUTPUT
- name: Restore pytest execution time cache
uses: actions/cache/restore@v4.2.3
with:
path: pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}.json
key: >-
${{ runner.os }}-${{ env.DEFAULT_PYTHON }}-${{
steps.generate-pytest-execution-time-key.outputs.key }}
restore-keys: |
${{ runner.os }}-${{ env.DEFAULT_PYTHON }}-pytest-execution-time-${{ env.HA_SHORT_VERSION }}-
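The step pair above implements a rolling cache: every run saves the report under a key ending in a fresh UTC timestamp, while `restore-keys` matches on the stable prefix so the next run picks up the most recently saved report. A minimal sketch of the key scheme (the context values are illustrative stand-ins for the workflow expressions):

```python
from datetime import datetime, timezone

# Illustrative stand-ins for ${{ runner.os }}, ${{ env.DEFAULT_PYTHON }},
# and ${{ env.HA_SHORT_VERSION }}.
runner_os, default_python, ha_short_version = "Linux", "3.13", "2025.4"

# Equivalent of the $(date -u ...) suffix: unique per run and
# lexicographically sortable, so newer keys compare greater.
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")
save_key = f"{runner_os}-{default_python}-pytest-execution-time-{ha_short_version}-{timestamp}"

# actions/cache restores the most recently created cache whose key starts
# with the restore prefix; with sortable timestamps that is also the max().
prefix = f"{runner_os}-{default_python}-pytest-execution-time-{ha_short_version}-"
existing = [prefix + "2025-03-26T10:00:00", prefix + "2025-03-27T09:30:00"]  # hypothetical
print(save_key)
print(max(k for k in existing if k.startswith(prefix)))  # newest cache wins
```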
- name: Restore base Python virtual environment
id: cache-venv
uses: actions/cache/restore@v4.2.3
@@ -905,7 +920,8 @@ jobs:
- name: Run split_tests.py
run: |
. venv/bin/activate
python -m script.split_tests ${{ needs.info.outputs.test_group_count }} tests
python -m script.split_tests ${{ needs.info.outputs.test_group_count }} \
tests pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}.json
- name: Upload pytest_buckets
uses: actions/upload-artifact@v4.6.2
with:
@@ -1002,6 +1018,7 @@ jobs:
${cov_params[@]} \
-o console_output_style=count \
-p no:sugar \
--execution-time-report-name pytest-execution-time-report-${{ matrix.python-version }}-${{ matrix.group }}.json \
--exclude-warning-annotations \
$(sed -n "${{ matrix.group }},1p" pytest_buckets.txt) \
2>&1 | tee pytest-${{ matrix.python-version }}-${{ matrix.group }}.txt
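Each matrix job runs only its own slice: `sed -n "N,1p"` prints just line N of `pytest_buckets.txt` (a sed quirk: when the end address is lower than the start, only the first addressed line matches). A rough Python equivalent, with a hypothetical two-line bucket file:

```python
from pathlib import Path

# Hypothetical bucket file as written by script/split_tests.py:
# one whitespace-separated list of test paths per line, one line per group.
Path("pytest_buckets.txt").write_text(
    "tests/components/demo tests/components/light\n"
    "tests/helpers\n"
)

group = 2  # stands in for ${{ matrix.group }}, 1-based
line = Path("pytest_buckets.txt").read_text().splitlines()[group - 1]
print(line.split())  # the test paths this runner hands to pytest
```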
@@ -1010,7 +1027,9 @@ jobs:
uses: actions/upload-artifact@v4.6.2
with:
name: pytest-${{ github.run_number }}-${{ matrix.python-version }}-${{ matrix.group }}
path: pytest-*.txt
path: |
pytest-*.txt
pytest-*.json
overwrite: true
- name: Upload coverage artifact
if: needs.info.outputs.skip_coverage != 'true'
@@ -1031,6 +1050,41 @@ jobs:
run: |
./script/check_dirty
pytest-combine-test-execution-time:
runs-on: ubuntu-24.04
needs:
- info
- pytest-full
name: Combine test execution times
steps:
- name: Generate partial pytest execution time restore key
id: generate-pytest-execution-time-key
run: |
echo "key=pytest-execution-time-${{env.HA_SHORT_VERSION }}-
$(date -u '%Y-%m-%dT%H:%M:%s')" >> $GITHUB_OUTPUT
- name: Download pytest execution time artifacts
uses: actions/download-artifact@v4.2.1
with:
pattern: pytest-${{ github.run_number }}-${{ env.DEFAULT_PYTHON }}-*
merge-multiple: true
- name: Combine files into one
run: |
jq -n 'reduce inputs as $item ({}; . *= $item)' \
pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}-*.json \
> pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}.json
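With `-n`, jq starts from `null` and folds every file from `inputs` into the accumulator; without it, the first report would be consumed as `.` and silently dropped from the merge. Since the reports are flat `{filename: seconds}` objects, the same merge can be sketched in Python (glob pattern assumed to match the per-group names above):

```python
import json
from glob import glob

merged: dict[str, float] = {}
# Flat-object equivalent of: jq -n 'reduce inputs as $item ({}; . *= $item)'
# (jq's * merges recursively; for one-level objects dict.update is the same,
# with later files winning on duplicate keys).
for name in sorted(glob("pytest-execution-time-report-3.13-*.json")):  # assumed pattern
    with open(name, encoding="utf-8") as fp:
        merged.update(json.load(fp))

with open("pytest-execution-time-report-3.13.json", "w", encoding="utf-8") as fp:
    json.dump(merged, fp, indent=2)
```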
- name: Upload combined pytest execution time artifact
uses: actions/upload-artifact@v4.6.2
with:
name: pytest-${{ github.run_number }}-${{ env.DEFAULT_PYTHON }}-time-report
path: pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}.json
- name: Save pytest execution time cache
uses: actions/cache/save@v4.2.3
with:
path: pytest-execution-time-report-${{ env.DEFAULT_PYTHON }}.json
key: >-
${{ runner.os }}-${{ env.DEFAULT_PYTHON }}-${{
steps.generate-pytest-execution-time-key.outputs.key }}
pytest-mariadb:
runs-on: ubuntu-24.04
services:

.gitignore

@@ -138,3 +138,6 @@ tmp_cache
# Will be created from script/split_tests.py
pytest_buckets.txt
# Contains test execution times used for splitting tests
pytest-execution-time-report.json

script/split_tests.py

@@ -5,11 +5,10 @@ from __future__ import annotations
import argparse
from dataclasses import dataclass, field
from math import ceil
from pathlib import Path
import subprocess
import sys
from typing import Final
from typing import Final, cast
from homeassistant.util.json import load_json_object
class Bucket:
@@ -19,13 +18,15 @@ class Bucket:
self,
):
"""Initialize bucket."""
self.total_tests = 0
self.approx_execution_time = 0.0
self.not_measured_files = 0
self._paths: list[str] = []
def add(self, part: TestFolder | TestFile) -> None:
"""Add tests to bucket."""
part.add_to_bucket()
self.total_tests += part.total_tests
self.approx_execution_time += part.approx_execution_time
self.not_measured_files += part.not_measured_files
self._paths.append(str(part.path))
def get_paths_line(self) -> str:
@@ -36,40 +37,62 @@
class BucketHolder:
"""Class to hold buckets."""
def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
def __init__(self, bucket_count: int) -> None:
"""Initialize bucket holder."""
self._tests_per_bucket = tests_per_bucket
self._bucket_count = bucket_count
self._buckets: list[Bucket] = [Bucket() for _ in range(bucket_count)]
def split_tests(self, test_folder: TestFolder) -> None:
"""Split tests into buckets."""
digits = len(str(test_folder.total_tests))
avg_execution_time = test_folder.approx_execution_time / self._bucket_count
avg_not_measured_files = test_folder.not_measured_files / self._bucket_count
digits = len(str(test_folder.approx_execution_time))
sorted_tests = sorted(
test_folder.get_all_flatten(), reverse=True, key=lambda x: x.total_tests
test_folder.get_all_flatten(),
key=lambda x: (x.not_measured_files, -x.approx_execution_time),
)
for tests in sorted_tests:
if tests.added_to_bucket:
# Already added to bucket
continue
print(f"{tests.total_tests:>{digits}} tests in {tests.path}")
smallest_bucket = min(self._buckets, key=lambda x: x.total_tests)
print(
f"{tests.approx_execution_time:>{digits}} approx execution time for {tests.path}"
)
is_file = isinstance(tests, TestFile)
for smallest_bucket in (
min(
self._buckets,
key=lambda x: (x.not_measured_files, x.approx_execution_time),
),
min(
self._buckets,
key=lambda x: (x.approx_execution_time, x.not_measured_files),
),
):
if (
smallest_bucket.total_tests + tests.total_tests < self._tests_per_bucket
(
smallest_bucket.approx_execution_time
+ tests.approx_execution_time
)
< avg_execution_time
and (smallest_bucket.not_measured_files + tests.not_measured_files)
< avg_not_measured_files
) or is_file:
smallest_bucket.add(tests)
# Ensure all files from the same folder are in the same bucket
# to ensure that syrupy correctly identifies unused snapshots
if is_file:
for other_test in tests.parent.children.values():
if other_test is tests or isinstance(other_test, TestFolder):
if other_test is tests or isinstance(
other_test, TestFolder
):
continue
print(
f"{other_test.total_tests:>{digits}} tests in {other_test.path} (same bucket)"
f"Adding {other_test.path} tests to same bucket due syrupy"
)
smallest_bucket.add(other_test)
break
# verify that all tests are added to a bucket
if not test_folder.added_to_bucket:
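The splitter is a greedy bin-packer: the flattened folders and files are sorted so that measured entries come before unmeasured ones and expensive entries come first, then each entry lands in the currently lightest bucket (single files are always placed; folders are capped near the per-bucket averages computed above). A toy sketch of the core idea, leaving out the folder grouping and syrupy handling:

```python
# Greedy split of (path, seconds) entries across two buckets.
tests = [
    ("tests/a", 120.0), ("tests/b", 90.0), ("tests/c", 60.0),
    ("tests/d", 45.0), ("tests/e", 30.0),
]  # illustrative execution times
bucket_count = 2
buckets: list[list[str]] = [[] for _ in range(bucket_count)]
totals = [0.0] * bucket_count

for path, seconds in sorted(tests, key=lambda t: -t[1]):  # most expensive first
    i = min(range(bucket_count), key=totals.__getitem__)  # lightest bucket so far
    buckets[i].append(path)
    totals[i] += seconds

print(totals)   # [165.0, 180.0] -> roughly balanced
print(buckets)  # [['tests/a', 'tests/d'], ['tests/b', 'tests/c', 'tests/e']]
```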
@@ -79,7 +102,9 @@ class BucketHolder:
"""Create output file."""
with Path("pytest_buckets.txt").open("w") as file:
for idx, bucket in enumerate(self._buckets):
print(f"Bucket {idx + 1} has {bucket.total_tests} tests")
print(
f"Bucket {idx + 1} execution time is ~{bucket.approx_execution_time}s with {bucket.not_measured_files} not measured files"
)
file.write(bucket.get_paths_line())
@@ -87,10 +112,11 @@ class BucketHolder:
class TestFile:
"""Class represents a single test file and the number of tests it has."""
total_tests: int
path: Path
parent: TestFolder
# 0 means not measured
approx_execution_time: float = 0.0
added_to_bucket: bool = field(default=False, init=False)
parent: TestFolder | None = field(default=None, init=False)
def add_to_bucket(self) -> None:
"""Add test file to bucket."""
@@ -98,9 +124,14 @@ class TestFile:
raise ValueError("Already added to bucket")
self.added_to_bucket = True
@property
def not_measured_files(self) -> int:
"""Return files not measured."""
return 1 if self.approx_execution_time == 0 else 0
def __gt__(self, other: TestFile) -> bool:
"""Return if greater than."""
return self.total_tests > other.total_tests
return self.approx_execution_time > other.approx_execution_time
class TestFolder:
@@ -112,9 +143,14 @@ class TestFolder:
self.children: dict[Path, TestFolder | TestFile] = {}
@property
def total_tests(self) -> int:
"""Return total tests."""
return sum([test.total_tests for test in self.children.values()])
def approx_execution_time(self) -> float:
"""Return approximate execution time."""
return sum([test.approx_execution_time for test in self.children.values()])
@property
def not_measured_files(self) -> int:
"""Return files not measured."""
return sum([test.not_measured_files for test in self.children.values()])
@property
def added_to_bucket(self) -> bool:
@@ -130,11 +166,13 @@ class TestFolder:
def __repr__(self) -> str:
"""Return representation."""
return (
f"TestFolder(total_tests={self.total_tests}, children={len(self.children)})"
)
return f"TestFolder(approx_execution_time={self.approx_execution_time}, children={len(self.children)})"
def add_test_file(self, file: TestFile) -> None:
def add_test_file(self, path: Path, execution_time: float) -> None:
"""Add test file to folder."""
self._add_test_file(TestFile(path, self, execution_time))
def _add_test_file(self, file: TestFile) -> None:
"""Add test file to folder."""
path = file.path
file.parent = self
@@ -142,7 +180,7 @@ class TestFolder:
if not relative_path.parts:
raise ValueError("Path is not a child of this folder")
if len(relative_path.parts) == 1:
if len(relative_path.parts) == 1 and path not in self.children:
self.children[path] = file
return
@@ -151,7 +189,7 @@ class TestFolder:
self.children[child_path] = child = TestFolder(child_path)
elif not isinstance(child, TestFolder):
raise ValueError("Child is not a folder")
child.add_test_file(file)
child._add_test_file(file)
def get_all_flatten(self) -> list[TestFolder | TestFile]:
"""Return self and all children as flatten list."""
@@ -164,35 +202,21 @@ class TestFolder:
return result
def collect_tests(path: Path) -> TestFolder:
"""Collect all tests."""
result = subprocess.run(
["pytest", "--collect-only", "-qq", "-p", "no:warnings", path],
check=False,
capture_output=True,
text=True,
)
def process_execution_time_file(
execution_time_file: Path, test_folder: TestFolder
) -> None:
"""Process the execution time file."""
for file, execution_time in load_json_object(execution_time_file).items():
test_folder.add_test_file(Path(file), cast(float, execution_time))
if result.returncode != 0:
print("Failed to collect tests:")
print(result.stderr)
print(result.stdout)
sys.exit(1)
folder = TestFolder(path)
for line in result.stdout.splitlines():
if not line.strip():
continue
file_path, _, total_tests = line.partition(": ")
if not path or not total_tests:
print(f"Unexpected line: {line}")
sys.exit(1)
file = TestFile(int(total_tests), Path(file_path))
folder.add_test_file(file)
return folder
def add_missing_test_files(folder: Path, test_folder: TestFolder) -> None:
"""Scan test folder for missing files."""
for path in folder.iterdir():
if path.is_dir():
add_missing_test_files(path, test_folder)
elif path.name.startswith("test_") and path.suffix == ".py":
test_folder.add_test_file(path, 0.0)
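The execution-time file is a flat JSON object mapping repo-relative test file paths to seconds; any file the scan finds that is missing from the report (for example a brand-new test) is added with `0.0`, which the rest of the script treats as "not measured". A hypothetical report (values invented):

```python
# Shape of pytest-execution-time-report.json consumed above:
report = {
    "tests/components/demo/test_init.py": 12.5,
    "tests/components/demo/test_light.py": 48.0,
    "tests/helpers/test_entity.py": 5.5,
}
# tests/components/demo/test_switch.py, absent here, would be added by
# add_missing_test_files() with execution_time=0.0 (one "not measured" file).
```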
def main() -> None:
@@ -213,24 +237,31 @@ def main() -> None:
type=check_greater_0,
)
parser.add_argument(
"path",
"test_folder",
help="Path to the test files to split into buckets",
type=Path,
)
parser.add_argument(
"execution_time_file",
help="Path to the file containing the execution time of each test",
type=Path,
)
arguments = parser.parse_args()
print("Collecting tests...")
tests = collect_tests(arguments.path)
tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)
tests = TestFolder(arguments.test_folder)
bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
if arguments.execution_time_file.exists():
print(f"Using execution time file: {arguments.execution_time_file}")
process_execution_time_file(arguments.execution_time_file, tests)
print("Scanning test files...")
add_missing_test_files(arguments.test_folder, tests)
bucket_holder = BucketHolder(arguments.bucket_count)
print("Splitting tests...")
bucket_holder.split_tests(tests)
print(f"Total tests: {tests.total_tests}")
print(f"Estimated tests per bucket: {tests_per_bucket}")
bucket_holder.create_ouput_file()
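Put together, a local dry run could look like the following sketch (bucket count and paths are illustrative; when the report file does not exist, every test file simply counts as unmeasured):

```python
import subprocess

# Hypothetical invocation mirroring the CI step above.
subprocess.run(
    [
        "python", "-m", "script.split_tests",
        "10",                                 # bucket count
        "tests",                              # test folder to scan
        "pytest-execution-time-report.json",  # optional execution time report
    ],
    check=True,
)
# pytest_buckets.txt now holds one line of test paths per bucket.
```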

tests/conftest.py

@@ -51,6 +51,10 @@ from . import patch_recorder
# Setup patching of dt_util time functions before any other Home Assistant imports
from . import patch_time # noqa: F401, isort:skip
import json
from _pytest.terminal import TerminalReporter
from homeassistant import components, core as ha, loader, runner
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.models import Credentials
@@ -123,6 +127,7 @@ if TYPE_CHECKING:
pytest.register_assert_rewrite("tests.common")
from .common import ( # noqa: E402, isort:skip
CLIENT_ID,
INSTANCES,
@@ -153,6 +158,37 @@ asyncio.set_event_loop_policy = lambda policy: None
def pytest_addoption(parser: pytest.Parser) -> None:
"""Register custom pytest options."""
parser.addoption("--dburl", action="store", default="sqlite://")
parser.addoption(
"--execution-time-report-name",
action="store",
default="pytest-execution-time-report.json",
)
class PytestExecutionTimeReport:
"""Pytest plugin to generate a JSON report with the execution time of each test."""
def pytest_terminal_summary(
self,
terminalreporter: TerminalReporter,
exitstatus: pytest.ExitCode,
config: pytest.Config,
) -> None:
"""Generate a JSON report with the execution time of each test."""
if config.option.collectonly:
return
raw_data: dict[str, list[float]] = {}
for replist in terminalreporter.stats.values():
for rep in replist:
if isinstance(rep, pytest.TestReport):
raw_data.setdefault(rep.location[0], []).append(rep.duration)
data = {filename: sum(values) for filename, values in raw_data.items()}
time_report_filename = config.option.execution_time_report_name
file = pathlib.Path(__file__).parents[1].joinpath(time_report_filename)
with open(file, "w", encoding="utf-8") as fp:
json.dump(data, fp, indent=2)
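Each `pytest.TestReport` in `terminalreporter.stats` covers a single phase (setup, call, or teardown) of one test, and `rep.location[0]` is the test's file, so summing per file yields whole-file wall time. A reconstruction of the aggregation with fabricated reports:

```python
import json

# (file, phase_duration) pairs as pytest would report them.
phases = [
    ("tests/components/demo/test_init.py", 0.25),  # setup
    ("tests/components/demo/test_init.py", 12.0),  # call
    ("tests/components/demo/test_init.py", 0.25),  # teardown
    ("tests/helpers/test_entity.py", 5.5),
]
raw: dict[str, list[float]] = {}
for filename, duration in phases:
    raw.setdefault(filename, []).append(duration)
data = {filename: sum(values) for filename, values in raw.items()}
print(json.dumps(data, indent=2))
# {"tests/components/demo/test_init.py": 12.5, "tests/helpers/test_entity.py": 5.5}
```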
def pytest_configure(config: pytest.Config) -> None:
@@ -167,6 +203,7 @@ def pytest_configure(config: pytest.Config) -> None:
# Temporary workaround until it is finalised inside syrupy
# See https://github.com/syrupy-project/syrupy/pull/901
SnapshotSession.finish = override_syrupy_finish
config.pluginmanager.register(PytestExecutionTimeReport())
def pytest_runtest_setup() -> None: