Improve script

This commit is contained in:
Robert Resch 2025-03-28 22:40:12 +01:00
parent 98f32c204b
commit 16c56d9f6b
No known key found for this signature in database
GPG Key ID: 9D9D9DCB43120143

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Final, cast
@ -18,7 +19,7 @@ class Bucket:
self,
):
"""Initialize bucket."""
self.approx_execution_time = 0.0
self.approx_execution_time = timedelta(seconds=0)
self.not_measured_files = 0
self._paths: list[str] = []
@ -46,6 +47,16 @@ def add_not_measured_files(
not_measured_files.add(test)
def sort_by_not_measured(bucket: Bucket) -> tuple[int, float]:
"""Sort by not measured files."""
return (bucket.not_measured_files, bucket.approx_execution_time.total_seconds())
def sort_by_execution_time(bucket: Bucket) -> tuple[float, int]:
"""Sort by execution time."""
return (bucket.approx_execution_time.total_seconds(), bucket.not_measured_files)
class BucketHolder:
"""Class to hold buckets."""
@ -58,7 +69,6 @@ class BucketHolder:
"""Split tests into buckets."""
avg_execution_time = test_folder.approx_execution_time / self._bucket_count
avg_not_measured_files = test_folder.not_measured_files / self._bucket_count
digits = len(str(round(test_folder.approx_execution_time, 0)))
sorted_tests = sorted(
test_folder.get_all_flatten(),
key=lambda x: (
@ -68,28 +78,22 @@ class BucketHolder:
),
)
not_measured_tests = set()
bucket_sort_keys = (
lambda x: (x.not_measured_files, x.approx_execution_time),
lambda x: (x.approx_execution_time, x.not_measured_files),
)
for tests in sorted_tests:
if tests.added_to_bucket:
# Already added to bucket
continue
print(
f"~{round(tests.approx_execution_time, 2):>{digits}}s execution time for {tests.path}"
)
print(f"~{tests.approx_execution_time} execution time for {tests.path}")
is_file = isinstance(tests, TestFile)
for sort_key in bucket_sort_keys:
smallest_bucket = min(self._buckets, key=sort_key)
sort_key = sort_by_execution_time
if tests.not_measured_files and tests.approx_execution_time == 0:
# If tests are not measured, sort by not measured files
sort_key = sort_by_not_measured
smallest_bucket = min(self._buckets, key=sort_key)
if (
(
smallest_bucket.approx_execution_time
+ tests.approx_execution_time
)
(smallest_bucket.approx_execution_time + tests.approx_execution_time)
< avg_execution_time
and (smallest_bucket.not_measured_files + tests.not_measured_files)
< avg_not_measured_files
@ -103,19 +107,16 @@ class BucketHolder:
# to ensure that syrupy correctly identifies unused snapshots
if is_file:
for other_test in tests.parent.children.values():
if other_test is tests or isinstance(
other_test, TestFolder
):
if other_test is tests or isinstance(other_test, TestFolder):
continue
print(
f"Adding {other_test.path} tests to same bucket due syrupy"
)
smallest_bucket.add(other_test)
add_not_measured_files(
tests,
other_test,
not_measured_tests,
)
break
# verify that all tests are added to a bucket
if not test_folder.added_to_bucket:
@ -131,12 +132,17 @@ class BucketHolder:
with Path("pytest_buckets.txt").open("w") as file:
for idx, bucket in enumerate(self._buckets):
print(
f"Bucket {idx + 1} execution time should be ~{bucket.approx_execution_time}s"
f"Bucket {idx + 1} execution time should be ~{str_without_milliseconds(bucket.approx_execution_time)}"
f" with {bucket.not_measured_files} not measured files"
)
file.write(bucket.get_paths_line())
def str_without_milliseconds(td: timedelta) -> str:
"""Return str without milliseconds."""
return str(td).split(".")[0]
@dataclass
class TestFile:
"""Class represents a single test file and the number of tests it has."""
@ -144,7 +150,7 @@ class TestFile:
path: Path
parent: TestFolder
# 0 means not measured
approx_execution_time: float = 0.0
approx_execution_time: timedelta
added_to_bucket: bool = field(default=False, init=False)
def add_to_bucket(self) -> None:
@ -156,7 +162,7 @@ class TestFile:
@property
def not_measured_files(self) -> int:
"""Return files not measured."""
return 1 if self.approx_execution_time == 0 else 0
return 1 if self.approx_execution_time.total_seconds() == 0 else 0
def __gt__(self, other: TestFile) -> bool:
"""Return if greater than."""
@ -176,9 +182,12 @@ class TestFolder:
self.children: dict[Path, TestFolder | TestFile] = {}
@property
def approx_execution_time(self) -> float:
def approx_execution_time(self) -> timedelta:
"""Return approximate execution time."""
return sum([test.approx_execution_time for test in self.children.values()])
time = timedelta(seconds=0)
for test in self.children.values():
time += test.approx_execution_time
return time
@property
def not_measured_files(self) -> int:
@ -213,7 +222,10 @@ class TestFolder:
self, path: Path, execution_time: float, skip_file_if_present: bool
) -> None:
"""Add test file to folder."""
self._add_test_file(TestFile(path, self, execution_time), skip_file_if_present)
self._add_test_file(
TestFile(path, self, timedelta(seconds=execution_time)),
skip_file_if_present,
)
def _add_test_file(self, file: TestFile, skip_file_if_present: bool) -> None:
"""Add test file to folder."""