From b7d9f26cee5f189213adeee07b4d6ef396a5256f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 8 Mar 2024 22:49:08 -1000 Subject: [PATCH] Cache the job type for entity service calls (#112793) --- homeassistant/core.py | 4 ++-- homeassistant/helpers/entity.py | 17 +++++++++++++++++ homeassistant/helpers/service.py | 5 ++++- tests/helpers/test_entity.py | 31 ++++++++++++++++++++++++++++++- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index b906f458bf3..6169df32cfb 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -299,7 +299,7 @@ class HassJob(Generic[_P, _R_co]): @cached_property def job_type(self) -> HassJobType: """Return the job type.""" - return self._job_type or _get_hassjob_callable_job_type(self.target) + return self._job_type or get_hassjob_callable_job_type(self.target) @property def cancel_on_shutdown(self) -> bool | None: @@ -319,7 +319,7 @@ class HassJobWithArgs: args: Iterable[Any] -def _get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType: +def get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType: """Determine the job type from the callable.""" # Check for partials to properly determine if coroutine function check_target = target diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 137d9075b65..191882b7afa 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -51,6 +51,7 @@ from homeassistant.core import ( HassJobType, HomeAssistant, callback, + get_hassjob_callable_job_type, get_release_channel, ) from homeassistant.exceptions import ( @@ -527,6 +528,8 @@ class Entity( __combined_unrecorded_attributes: frozenset[str] = ( _entity_component_unrecorded_attributes | _unrecorded_attributes ) + # Job type cache + _job_types: dict[str, HassJobType] | None = None # StateInfo. Set by EntityPlatform by calling async_internal_added_to_hass # While not purely typed, it makes typehinting more useful for us @@ -568,6 +571,20 @@ class Entity( cls._entity_component_unrecorded_attributes | cls._unrecorded_attributes ) + def get_hassjob_type(self, function_name: str) -> HassJobType: + """Get the job type function for the given name. + + This is used for entity service calls to avoid + figuring out the job type each time. + """ + if not self._job_types: + self._job_types = {} + if function_name not in self._job_types: + self._job_types[function_name] = get_hassjob_callable_job_type( + getattr(self, function_name) + ) + return self._job_types[function_name] + @cached_property def should_poll(self) -> bool: """Return True if entity has to be polled for state. diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 223833fc5a5..d954d7b9682 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -963,7 +963,10 @@ async def _handle_entity_call( task: asyncio.Future[ServiceResponse] | None if isinstance(func, str): - job = HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type] + job = HassJob( + partial(getattr(entity, func), **data), # type: ignore[arg-type] + job_type=entity.get_hassjob_type(func), + ) task = hass.async_run_hass_job(job, eager_start=True) else: task = hass.async_run_hass_job(func, entity, data, eager_start=True) diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 241a26e6529..ec281bf4c0d 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -23,7 +23,13 @@ from homeassistant.const import ( STATE_UNAVAILABLE, STATE_UNKNOWN, ) -from homeassistant.core import Context, HomeAssistant, HomeAssistantError +from homeassistant.core import ( + Context, + HassJobType, + HomeAssistant, + HomeAssistantError, + callback, +) from homeassistant.helpers import device_registry as dr, entity, entity_registry as er from homeassistant.helpers.entity_component import async_update_entity from homeassistant.helpers.typing import UNDEFINED, UndefinedType @@ -2559,3 +2565,26 @@ async def test_reset_right_after_remove_entity_registry( assert len(ent.remove_calls) == 1 assert hass.states.get("test.test") is None + + +async def test_get_hassjob_type(hass: HomeAssistant) -> None: + """Test get_hassjob_type.""" + + class AsyncEntity(entity.Entity): + """Test entity.""" + + def update(self): + """Test update Executor.""" + + async def async_update(self): + """Test update Coroutinefunction.""" + + @callback + def update_callback(self): + """Test update Callback.""" + + ent_1 = AsyncEntity() + + assert ent_1.get_hassjob_type("update") is HassJobType.Executor + assert ent_1.get_hassjob_type("async_update") is HassJobType.Coroutinefunction + assert ent_1.get_hassjob_type("update_callback") is HassJobType.Callback