mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Fix inner callback decorators with partials (#102873)
This commit is contained in:
parent
524e20536d
commit
009dc91b97
@ -92,6 +92,7 @@ def pong_message(iden: int) -> dict[str, Any]:
|
||||
return {"id": iden, "type": "pong"}
|
||||
|
||||
|
||||
@callback
|
||||
def _forward_events_check_permissions(
|
||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||
user: User,
|
||||
@ -109,6 +110,7 @@ def _forward_events_check_permissions(
|
||||
send_message(messages.cached_event_message(msg_id, event))
|
||||
|
||||
|
||||
@callback
|
||||
def _forward_events_unconditional(
|
||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||
msg_id: int,
|
||||
@ -135,17 +137,15 @@ def handle_subscribe_events(
|
||||
raise Unauthorized
|
||||
|
||||
if event_type == EVENT_STATE_CHANGED:
|
||||
forward_events = callback(
|
||||
partial(
|
||||
_forward_events_check_permissions,
|
||||
connection.send_message,
|
||||
connection.user,
|
||||
msg["id"],
|
||||
)
|
||||
forward_events = partial(
|
||||
_forward_events_check_permissions,
|
||||
connection.send_message,
|
||||
connection.user,
|
||||
msg["id"],
|
||||
)
|
||||
else:
|
||||
forward_events = callback(
|
||||
partial(_forward_events_unconditional, connection.send_message, msg["id"])
|
||||
forward_events = partial(
|
||||
_forward_events_unconditional, connection.send_message, msg["id"]
|
||||
)
|
||||
|
||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||
@ -298,6 +298,7 @@ def _send_handle_get_states_response(
|
||||
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
|
||||
|
||||
|
||||
@callback
|
||||
def _forward_entity_changes(
|
||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||
entity_ids: set[str],
|
||||
@ -337,14 +338,12 @@ def handle_subscribe_entities(
|
||||
states = _async_get_allowed_states(hass, connection)
|
||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED,
|
||||
callback(
|
||||
partial(
|
||||
_forward_entity_changes,
|
||||
connection.send_message,
|
||||
entity_ids,
|
||||
connection.user,
|
||||
msg["id"],
|
||||
)
|
||||
partial(
|
||||
_forward_entity_changes,
|
||||
connection.send_message,
|
||||
entity_ids,
|
||||
connection.user,
|
||||
msg["id"],
|
||||
),
|
||||
run_immediately=True,
|
||||
)
|
||||
|
@ -209,6 +209,18 @@ def is_callback(func: Callable[..., Any]) -> bool:
|
||||
return getattr(func, "_hass_callback", False) is True
|
||||
|
||||
|
||||
def is_callback_check_partial(target: Callable[..., Any]) -> bool:
|
||||
"""Check if function is safe to be called in the event loop.
|
||||
|
||||
This version of is_callback will also check if the target is a partial
|
||||
and walk the chain of partials to find the original function.
|
||||
"""
|
||||
check_target = target
|
||||
while isinstance(check_target, functools.partial):
|
||||
check_target = check_target.func
|
||||
return is_callback(check_target)
|
||||
|
||||
|
||||
class _Hass(threading.local):
|
||||
"""Container which makes a HomeAssistant instance available to the event loop."""
|
||||
|
||||
@ -1141,9 +1153,9 @@ class EventBus:
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if event_filter is not None and not is_callback(event_filter):
|
||||
if event_filter is not None and not is_callback_check_partial(event_filter):
|
||||
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
|
||||
if run_immediately and not is_callback(listener):
|
||||
if run_immediately and not is_callback_check_partial(listener):
|
||||
raise HomeAssistantError(f"Event listener {listener} is not a callback")
|
||||
return self._async_listen_filterable_job(
|
||||
event_type,
|
||||
|
@ -394,8 +394,8 @@ def _async_track_event(
|
||||
if listeners_key not in hass_data:
|
||||
hass_data[listeners_key] = hass.bus.async_listen(
|
||||
event_type,
|
||||
callback(ft.partial(dispatcher_callable, hass, callbacks)),
|
||||
event_filter=callback(ft.partial(filter_callable, hass, callbacks)),
|
||||
ft.partial(dispatcher_callable, hass, callbacks),
|
||||
event_filter=ft.partial(filter_callable, hass, callbacks),
|
||||
)
|
||||
|
||||
job = HassJob(action, f"track {event_type} event {keys}")
|
||||
|
@ -2498,3 +2498,39 @@ async def test_get_release_channel(version: str, release_channel: str) -> None:
|
||||
"""Test if release channel detection works from Home Assistant version number."""
|
||||
with patch("homeassistant.core.__version__", f"{version}"):
|
||||
assert get_release_channel() == release_channel
|
||||
|
||||
|
||||
def test_is_callback_check_partial():
|
||||
"""Test is_callback_check_partial matches HassJob."""
|
||||
|
||||
@ha.callback
|
||||
def callback_func():
|
||||
pass
|
||||
|
||||
def not_callback_func():
|
||||
pass
|
||||
|
||||
assert ha.is_callback(callback_func)
|
||||
assert HassJob(callback_func).job_type == ha.HassJobType.Callback
|
||||
assert ha.is_callback_check_partial(functools.partial(callback_func))
|
||||
assert HassJob(functools.partial(callback_func)).job_type == ha.HassJobType.Callback
|
||||
assert ha.is_callback_check_partial(
|
||||
functools.partial(functools.partial(callback_func))
|
||||
)
|
||||
assert HassJob(functools.partial(functools.partial(callback_func))).job_type == (
|
||||
ha.HassJobType.Callback
|
||||
)
|
||||
assert not ha.is_callback_check_partial(not_callback_func)
|
||||
assert HassJob(not_callback_func).job_type == ha.HassJobType.Executor
|
||||
assert not ha.is_callback_check_partial(functools.partial(not_callback_func))
|
||||
assert HassJob(functools.partial(not_callback_func)).job_type == (
|
||||
ha.HassJobType.Executor
|
||||
)
|
||||
|
||||
# We check the inner function, not the outer one
|
||||
assert not ha.is_callback_check_partial(
|
||||
ha.callback(functools.partial(not_callback_func))
|
||||
)
|
||||
assert HassJob(ha.callback(functools.partial(not_callback_func))).job_type == (
|
||||
ha.HassJobType.Executor
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user