From 61b2e4ca323c0f8bf37a0202c407542568e3d8ca Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 1 Jul 2024 14:05:30 +0200 Subject: [PATCH] Add Context to service_calls fixture (#120923) --- tests/conftest.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f9b65c5f138..3cef2dd0279 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,6 +55,7 @@ from homeassistant.config import YAML_CONFIG_FILE from homeassistant.config_entries import ConfigEntries, ConfigEntry, ConfigEntryState from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import ( + Context, CoreState, HassJob, HomeAssistant, @@ -1661,7 +1662,7 @@ def label_registry(hass: HomeAssistant) -> lr.LabelRegistry: @pytest.fixture -def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall]]: +def service_calls(hass: HomeAssistant) -> Generator[list[ServiceCall]]: """Track all service calls.""" calls = [] @@ -1672,15 +1673,23 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall domain: str, service: str, service_data: dict[str, Any] | None = None, - **kwargs: Any, + blocking: bool = False, + context: Context | None = None, + target: dict[str, Any] | None = None, + return_response: bool = False, ) -> ServiceResponse: - calls.append(ServiceCall(domain, service, service_data)) + calls.append( + ServiceCall(domain, service, service_data, context, return_response) + ) try: return await _original_async_call( domain, service, service_data, - **kwargs, + blocking, + context, + target, + return_response, ) except ha.ServiceNotFound: _LOGGER.debug("Ignoring unknown service call to %s.%s", domain, service) @@ -1697,7 +1706,7 @@ def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: @pytest.fixture -def disable_block_async_io() -> Generator[Any, Any, None]: +def disable_block_async_io() -> Generator[None]: """Fixture to disable the loop protection from block_async_io.""" yield calls = block_async_io._BLOCKED_CALLS.calls