From 64636a4310f4765c7477e3588b025b85a8d4586a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 9 May 2022 14:45:53 -0700 Subject: [PATCH] Add service entity context (#71558) Co-authored-by: Shay Levy --- homeassistant/helpers/service.py | 11 +++++++++++ tests/helpers/test_service.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 4cd38aa9768..975b05067b2 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable, Iterable +from contextvars import ContextVar import dataclasses from functools import partial, wraps import logging @@ -63,6 +64,15 @@ _LOGGER = logging.getLogger(__name__) SERVICE_DESCRIPTION_CACHE = "service_description_cache" +_current_entity: ContextVar[str | None] = ContextVar("current_entity", default=None) + + +@callback +def async_get_current_entity() -> str | None: + """Get the current entity on which the service is called.""" + return _current_entity.get() + + class ServiceParams(TypedDict): """Type for service call parameters.""" @@ -706,6 +716,7 @@ async def _handle_entity_call( ) -> None: """Handle calling service method.""" entity.async_set_context(context) + _current_entity.set(entity.entity_id) if isinstance(func, str): result = hass.async_run_job(partial(getattr(entity, func), **data)) # type: ignore[arg-type] diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 1b8de6ca6e2..cf87377dd8f 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -19,12 +19,12 @@ from homeassistant.const import ( STATE_ON, ) from homeassistant.helpers import ( + config_validation as cv, device_registry as dev_reg, entity_registry as ent_reg, service, template, ) -import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import EntityCategory from homeassistant.setup import async_setup_component @@ -1205,3 +1205,17 @@ async def test_async_extract_config_entry_ids(hass): ) assert await service.async_extract_config_entry_ids(hass, call) == {"abc"} + + +async def test_current_entity_context(hass, mock_entities): + """Test we set the current entity context var.""" + + async def mock_service(entity, call): + assert entity.entity_id == service.async_get_current_entity() + + await service.entity_service_call( + hass, + [Mock(entities=mock_entities)], + mock_service, + ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), + )