From 009dc91b97c6dfd5a134b49844f44d3d537da1df Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 28 Oct 2023 08:38:42 -0500 Subject: [PATCH] Fix inner callback decorators with partials (#102873) --- .../components/websocket_api/commands.py | 33 +++++++++-------- homeassistant/core.py | 16 +++++++-- homeassistant/helpers/event.py | 4 +-- tests/test_core.py | 36 +++++++++++++++++++ 4 files changed, 68 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index a29bee86116..b69ff57d015 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -92,6 +92,7 @@ def pong_message(iden: int) -> dict[str, Any]: return {"id": iden, "type": "pong"} +@callback def _forward_events_check_permissions( send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], user: User, @@ -109,6 +110,7 @@ def _forward_events_check_permissions( send_message(messages.cached_event_message(msg_id, event)) +@callback def _forward_events_unconditional( send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], msg_id: int, @@ -135,17 +137,15 @@ def handle_subscribe_events( raise Unauthorized if event_type == EVENT_STATE_CHANGED: - forward_events = callback( - partial( - _forward_events_check_permissions, - connection.send_message, - connection.user, - msg["id"], - ) + forward_events = partial( + _forward_events_check_permissions, + connection.send_message, + connection.user, + msg["id"], ) else: - forward_events = callback( - partial(_forward_events_unconditional, connection.send_message, msg["id"]) + forward_events = partial( + _forward_events_unconditional, connection.send_message, msg["id"] ) connection.subscriptions[msg["id"]] = hass.bus.async_listen( @@ -298,6 +298,7 @@ def _send_handle_get_states_response( connection.send_message(construct_result_message(msg_id, f"[{joined_states}]")) +@callback def _forward_entity_changes( send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], entity_ids: set[str], @@ -337,14 +338,12 @@ def handle_subscribe_entities( states = _async_get_allowed_states(hass, connection) connection.subscriptions[msg["id"]] = hass.bus.async_listen( EVENT_STATE_CHANGED, - callback( - partial( - _forward_entity_changes, - connection.send_message, - entity_ids, - connection.user, - msg["id"], - ) + partial( + _forward_entity_changes, + connection.send_message, + entity_ids, + connection.user, + msg["id"], ), run_immediately=True, ) diff --git a/homeassistant/core.py b/homeassistant/core.py index 2025d813be4..48cc70e7727 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -209,6 +209,18 @@ def is_callback(func: Callable[..., Any]) -> bool: return getattr(func, "_hass_callback", False) is True +def is_callback_check_partial(target: Callable[..., Any]) -> bool: + """Check if function is safe to be called in the event loop. + + This version of is_callback will also check if the target is a partial + and walk the chain of partials to find the original function. + """ + check_target = target + while isinstance(check_target, functools.partial): + check_target = check_target.func + return is_callback(check_target) + + class _Hass(threading.local): """Container which makes a HomeAssistant instance available to the event loop.""" @@ -1141,9 +1153,9 @@ class EventBus: This method must be run in the event loop. """ - if event_filter is not None and not is_callback(event_filter): + if event_filter is not None and not is_callback_check_partial(event_filter): raise HomeAssistantError(f"Event filter {event_filter} is not a callback") - if run_immediately and not is_callback(listener): + if run_immediately and not is_callback_check_partial(listener): raise HomeAssistantError(f"Event listener {listener} is not a callback") return self._async_listen_filterable_job( event_type, diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 75e2340a187..ab0fc25f04d 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -394,8 +394,8 @@ def _async_track_event( if listeners_key not in hass_data: hass_data[listeners_key] = hass.bus.async_listen( event_type, - callback(ft.partial(dispatcher_callable, hass, callbacks)), - event_filter=callback(ft.partial(filter_callable, hass, callbacks)), + ft.partial(dispatcher_callable, hass, callbacks), + event_filter=ft.partial(filter_callable, hass, callbacks), ) job = HassJob(action, f"track {event_type} event {keys}") diff --git a/tests/test_core.py b/tests/test_core.py index 957da634dce..9fed1141a76 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2498,3 +2498,39 @@ async def test_get_release_channel(version: str, release_channel: str) -> None: """Test if release channel detection works from Home Assistant version number.""" with patch("homeassistant.core.__version__", f"{version}"): assert get_release_channel() == release_channel + + +def test_is_callback_check_partial(): + """Test is_callback_check_partial matches HassJob.""" + + @ha.callback + def callback_func(): + pass + + def not_callback_func(): + pass + + assert ha.is_callback(callback_func) + assert HassJob(callback_func).job_type == ha.HassJobType.Callback + assert ha.is_callback_check_partial(functools.partial(callback_func)) + assert HassJob(functools.partial(callback_func)).job_type == ha.HassJobType.Callback + assert ha.is_callback_check_partial( + functools.partial(functools.partial(callback_func)) + ) + assert HassJob(functools.partial(functools.partial(callback_func))).job_type == ( + ha.HassJobType.Callback + ) + assert not ha.is_callback_check_partial(not_callback_func) + assert HassJob(not_callback_func).job_type == ha.HassJobType.Executor + assert not ha.is_callback_check_partial(functools.partial(not_callback_func)) + assert HassJob(functools.partial(not_callback_func)).job_type == ( + ha.HassJobType.Executor + ) + + # We check the inner function, not the outer one + assert not ha.is_callback_check_partial( + ha.callback(functools.partial(not_callback_func)) + ) + assert HassJob(ha.callback(functools.partial(not_callback_func))).job_type == ( + ha.HassJobType.Executor + )