Fix inner callback decorators with partials (#102873)

This commit is contained in:
J. Nick Koston 2023-10-28 08:38:42 -05:00 committed by GitHub
parent 524e20536d
commit 009dc91b97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 21 deletions

View File

@ -92,6 +92,7 @@ def pong_message(iden: int) -> dict[str, Any]:
return {"id": iden, "type": "pong"}
@callback
def _forward_events_check_permissions(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
user: User,
@ -109,6 +110,7 @@ def _forward_events_check_permissions(
send_message(messages.cached_event_message(msg_id, event))
@callback
def _forward_events_unconditional(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
msg_id: int,
@ -135,17 +137,15 @@ def handle_subscribe_events(
raise Unauthorized
if event_type == EVENT_STATE_CHANGED:
forward_events = callback(
partial(
_forward_events_check_permissions,
connection.send_message,
connection.user,
msg["id"],
)
forward_events = partial(
_forward_events_check_permissions,
connection.send_message,
connection.user,
msg["id"],
)
else:
forward_events = callback(
partial(_forward_events_unconditional, connection.send_message, msg["id"])
forward_events = partial(
_forward_events_unconditional, connection.send_message, msg["id"]
)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
@ -298,6 +298,7 @@ def _send_handle_get_states_response(
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
@callback
def _forward_entity_changes(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str],
@ -337,14 +338,12 @@ def handle_subscribe_entities(
states = _async_get_allowed_states(hass, connection)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
EVENT_STATE_CHANGED,
callback(
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
connection.user,
msg["id"],
)
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
connection.user,
msg["id"],
),
run_immediately=True,
)

View File

@ -209,6 +209,18 @@ def is_callback(func: Callable[..., Any]) -> bool:
return getattr(func, "_hass_callback", False) is True
def is_callback_check_partial(target: Callable[..., Any]) -> bool:
"""Check if function is safe to be called in the event loop.
This version of is_callback will also check if the target is a partial
and walk the chain of partials to find the original function.
"""
check_target = target
while isinstance(check_target, functools.partial):
check_target = check_target.func
return is_callback(check_target)
class _Hass(threading.local):
"""Container which makes a HomeAssistant instance available to the event loop."""
@ -1141,9 +1153,9 @@ class EventBus:
This method must be run in the event loop.
"""
if event_filter is not None and not is_callback(event_filter):
if event_filter is not None and not is_callback_check_partial(event_filter):
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
if run_immediately and not is_callback(listener):
if run_immediately and not is_callback_check_partial(listener):
raise HomeAssistantError(f"Event listener {listener} is not a callback")
return self._async_listen_filterable_job(
event_type,

View File

@ -394,8 +394,8 @@ def _async_track_event(
if listeners_key not in hass_data:
hass_data[listeners_key] = hass.bus.async_listen(
event_type,
callback(ft.partial(dispatcher_callable, hass, callbacks)),
event_filter=callback(ft.partial(filter_callable, hass, callbacks)),
ft.partial(dispatcher_callable, hass, callbacks),
event_filter=ft.partial(filter_callable, hass, callbacks),
)
job = HassJob(action, f"track {event_type} event {keys}")

View File

@ -2498,3 +2498,39 @@ async def test_get_release_channel(version: str, release_channel: str) -> None:
"""Test if release channel detection works from Home Assistant version number."""
with patch("homeassistant.core.__version__", f"{version}"):
assert get_release_channel() == release_channel
def test_is_callback_check_partial():
"""Test is_callback_check_partial matches HassJob."""
@ha.callback
def callback_func():
pass
def not_callback_func():
pass
assert ha.is_callback(callback_func)
assert HassJob(callback_func).job_type == ha.HassJobType.Callback
assert ha.is_callback_check_partial(functools.partial(callback_func))
assert HassJob(functools.partial(callback_func)).job_type == ha.HassJobType.Callback
assert ha.is_callback_check_partial(
functools.partial(functools.partial(callback_func))
)
assert HassJob(functools.partial(functools.partial(callback_func))).job_type == (
ha.HassJobType.Callback
)
assert not ha.is_callback_check_partial(not_callback_func)
assert HassJob(not_callback_func).job_type == ha.HassJobType.Executor
assert not ha.is_callback_check_partial(functools.partial(not_callback_func))
assert HassJob(functools.partial(not_callback_func)).job_type == (
ha.HassJobType.Executor
)
# We check the inner function, not the outer one
assert not ha.is_callback_check_partial(
ha.callback(functools.partial(not_callback_func))
)
assert HassJob(ha.callback(functools.partial(not_callback_func))).job_type == (
ha.HassJobType.Executor
)