From 91157c21efb76e226510e8c83195214f73fc788d Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 4 Nov 2024 18:59:27 +0100 Subject: [PATCH] Reapply "Fix unused snapshots not triggering failure in CI" (#129311) --- .github/workflows/ci.yaml | 4 + tests/conftest.py | 8 +- tests/syrupy.py | 169 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 02e8b4f180d..cae9795d715 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -949,6 +949,7 @@ jobs: --timeout=9 \ --durations=10 \ --numprocesses auto \ + --snapshot-details \ --dist=loadfile \ ${cov_params[@]} \ -o console_output_style=count \ @@ -1071,6 +1072,7 @@ jobs: -qq \ --timeout=20 \ --numprocesses 1 \ + --snapshot-details \ ${cov_params[@]} \ -o console_output_style=count \ --durations=10 \ @@ -1199,6 +1201,7 @@ jobs: -qq \ --timeout=9 \ --numprocesses 1 \ + --snapshot-details \ ${cov_params[@]} \ -o console_output_style=count \ --durations=0 \ @@ -1345,6 +1348,7 @@ jobs: -qq \ --timeout=9 \ --numprocesses auto \ + --snapshot-details \ ${cov_params[@]} \ -o console_output_style=count \ --durations=0 \ diff --git a/tests/conftest.py b/tests/conftest.py index 10c9a740256..c60018413e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,7 @@ import pytest_socket import requests_mock import respx from syrupy.assertion import SnapshotAssertion +from syrupy.session import SnapshotSession from homeassistant import block_async_io from homeassistant.exceptions import ServiceNotFound @@ -92,7 +93,7 @@ from homeassistant.util.async_ import create_eager_task, get_scheduled_timer_han from homeassistant.util.json import json_loads from .ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS -from .syrupy import HomeAssistantSnapshotExtension +from .syrupy import HomeAssistantSnapshotExtension, override_syrupy_finish from .typing import ( ClientSessionGenerator, MockHAClientWebSocket, @@ -149,6 +150,11 @@ def pytest_configure(config: pytest.Config) -> None: if config.getoption("verbose") > 0: logging.getLogger().setLevel(logging.DEBUG) + # Override default finish to detect unused snapshots despite xdist + # Temporary workaround until it is finalised inside syrupy + # See https://github.com/syrupy-project/syrupy/pull/901 + SnapshotSession.finish = override_syrupy_finish + def pytest_runtest_setup() -> None: """Prepare pytest_socket and freezegun. diff --git a/tests/syrupy.py b/tests/syrupy.py index 268ee59243f..a3b3f763063 100644 --- a/tests/syrupy.py +++ b/tests/syrupy.py @@ -5,14 +5,22 @@ from __future__ import annotations from contextlib import suppress import dataclasses from enum import IntFlag +import json +import os from pathlib import Path from typing import Any import attr import attrs +import pytest +from syrupy.constants import EXIT_STATUS_FAIL_UNUSED +from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension from syrupy.location import PyTestLocation +from syrupy.report import SnapshotReport +from syrupy.session import ItemStatus, SnapshotSession from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData +from syrupy.utils import is_xdist_controller, is_xdist_worker import voluptuous as vol import voluptuous_serialize @@ -246,3 +254,164 @@ class HomeAssistantSnapshotExtension(AmberSnapshotExtension): """ test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath("snapshots")) + + +# Classes and Methods to override default finish behavior in syrupy +# This is needed to handle the xdist plugin in pytest +# The default implementation does not handle the xdist plugin +# and will not work correctly when running tests in parallel +# with pytest-xdist. +# Temporary workaround until it is finalised inside syrupy +# See https://github.com/syrupy-project/syrupy/pull/901 + + +class _FakePytestObject: + """Fake object.""" + + def __init__(self, collected_item: dict[str, str]) -> None: + """Initialise fake object.""" + self.__module__ = collected_item["modulename"] + self.__name__ = collected_item["methodname"] + + +class _FakePytestItem: + """Fake pytest.Item object.""" + + def __init__(self, collected_item: dict[str, str]) -> None: + """Initialise fake pytest.Item object.""" + self.nodeid = collected_item["nodeid"] + self.name = collected_item["name"] + self.path = Path(collected_item["path"]) + self.obj = _FakePytestObject(collected_item) + + +def _serialize_collections(collections: SnapshotCollections) -> dict[str, Any]: + return { + k: [c.name for c in v] for k, v in collections._snapshot_collections.items() + } + + +def _serialize_report( + report: SnapshotReport, + collected_items: set[pytest.Item], + selected_items: dict[str, ItemStatus], +) -> dict[str, Any]: + return { + "discovered": _serialize_collections(report.discovered), + "created": _serialize_collections(report.created), + "failed": _serialize_collections(report.failed), + "matched": _serialize_collections(report.matched), + "updated": _serialize_collections(report.updated), + "used": _serialize_collections(report.used), + "_collected_items": [ + { + "nodeid": c.nodeid, + "name": c.name, + "path": str(c.path), + "modulename": c.obj.__module__, + "methodname": c.obj.__name__, + } + for c in list(collected_items) + ], + "_selected_items": { + key: status.value for key, status in selected_items.items() + }, + } + + +def _merge_serialized_collections( + collections: SnapshotCollections, json_data: dict[str, list[str]] +) -> None: + if not json_data: + return + for location, names in json_data.items(): + snapshot_collection = SnapshotCollection(location=location) + for name in names: + snapshot_collection.add(Snapshot(name)) + collections.update(snapshot_collection) + + +def _merge_serialized_report(report: SnapshotReport, json_data: dict[str, Any]) -> None: + _merge_serialized_collections(report.discovered, json_data["discovered"]) + _merge_serialized_collections(report.created, json_data["created"]) + _merge_serialized_collections(report.failed, json_data["failed"]) + _merge_serialized_collections(report.matched, json_data["matched"]) + _merge_serialized_collections(report.updated, json_data["updated"]) + _merge_serialized_collections(report.used, json_data["used"]) + for collected_item in json_data["_collected_items"]: + custom_item = _FakePytestItem(collected_item) + if not any( + t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid + for t in report.collected_items + ): + report.collected_items.add(custom_item) + for key, selected_item in json_data["_selected_items"].items(): + if key in report.selected_items: + status = ItemStatus(selected_item) + if status != ItemStatus.NOT_RUN: + report.selected_items[key] = status + else: + report.selected_items[key] = ItemStatus(selected_item) + + +def override_syrupy_finish(self: SnapshotSession) -> int: + """Override the finish method to allow for custom handling.""" + exitstatus = 0 + self.flush_snapshot_write_queue() + self.report = SnapshotReport( + base_dir=self.pytest_session.config.rootpath, + collected_items=self._collected_items, + selected_items=self._selected_items, + assertions=self._assertions, + options=self.pytest_session.config.option, + ) + + needs_xdist_merge = self.update_snapshots or bool( + self.pytest_session.config.option.include_snapshot_details + ) + + if is_xdist_worker(): + if not needs_xdist_merge: + return exitstatus + with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f: + f.write(os.getenv("PYTEST_XDIST_WORKER_COUNT")) + with open( + f".pytest_syrupy_{os.getenv("PYTEST_XDIST_WORKER")}_result", + "w", + encoding="utf-8", + ) as f: + json.dump( + _serialize_report( + self.report, self._collected_items, self._selected_items + ), + f, + indent=2, + ) + return exitstatus + if is_xdist_controller(): + return exitstatus + + if needs_xdist_merge: + worker_count = None + try: + with open(".pytest_syrupy_worker_count", encoding="utf-8") as f: + worker_count = f.read() + os.remove(".pytest_syrupy_worker_count") + except FileNotFoundError: + pass + + if worker_count: + for i in range(int(worker_count)): + with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f: + _merge_serialized_report(self.report, json.load(f)) + os.remove(f".pytest_syrupy_gw{i}_result") + + if self.report.num_unused: + if self.update_snapshots: + self.remove_unused_snapshots( + unused_snapshot_collections=self.report.unused, + used_snapshot_collections=self.report.used, + ) + elif not self.warn_unused_snapshots: + exitstatus |= EXIT_STATUS_FAIL_UNUSED + return exitstatus