mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Fix inner callback decorators with partials (#102873)
This commit is contained in:
parent
524e20536d
commit
009dc91b97
@ -92,6 +92,7 @@ def pong_message(iden: int) -> dict[str, Any]:
|
|||||||
return {"id": iden, "type": "pong"}
|
return {"id": iden, "type": "pong"}
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def _forward_events_check_permissions(
|
def _forward_events_check_permissions(
|
||||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||||
user: User,
|
user: User,
|
||||||
@ -109,6 +110,7 @@ def _forward_events_check_permissions(
|
|||||||
send_message(messages.cached_event_message(msg_id, event))
|
send_message(messages.cached_event_message(msg_id, event))
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def _forward_events_unconditional(
|
def _forward_events_unconditional(
|
||||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||||
msg_id: int,
|
msg_id: int,
|
||||||
@ -135,17 +137,15 @@ def handle_subscribe_events(
|
|||||||
raise Unauthorized
|
raise Unauthorized
|
||||||
|
|
||||||
if event_type == EVENT_STATE_CHANGED:
|
if event_type == EVENT_STATE_CHANGED:
|
||||||
forward_events = callback(
|
forward_events = partial(
|
||||||
partial(
|
_forward_events_check_permissions,
|
||||||
_forward_events_check_permissions,
|
connection.send_message,
|
||||||
connection.send_message,
|
connection.user,
|
||||||
connection.user,
|
msg["id"],
|
||||||
msg["id"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
forward_events = callback(
|
forward_events = partial(
|
||||||
partial(_forward_events_unconditional, connection.send_message, msg["id"])
|
_forward_events_unconditional, connection.send_message, msg["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||||
@ -298,6 +298,7 @@ def _send_handle_get_states_response(
|
|||||||
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
|
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def _forward_entity_changes(
|
def _forward_entity_changes(
|
||||||
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
|
||||||
entity_ids: set[str],
|
entity_ids: set[str],
|
||||||
@ -337,14 +338,12 @@ def handle_subscribe_entities(
|
|||||||
states = _async_get_allowed_states(hass, connection)
|
states = _async_get_allowed_states(hass, connection)
|
||||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
callback(
|
partial(
|
||||||
partial(
|
_forward_entity_changes,
|
||||||
_forward_entity_changes,
|
connection.send_message,
|
||||||
connection.send_message,
|
entity_ids,
|
||||||
entity_ids,
|
connection.user,
|
||||||
connection.user,
|
msg["id"],
|
||||||
msg["id"],
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
run_immediately=True,
|
run_immediately=True,
|
||||||
)
|
)
|
||||||
|
@ -209,6 +209,18 @@ def is_callback(func: Callable[..., Any]) -> bool:
|
|||||||
return getattr(func, "_hass_callback", False) is True
|
return getattr(func, "_hass_callback", False) is True
|
||||||
|
|
||||||
|
|
||||||
|
def is_callback_check_partial(target: Callable[..., Any]) -> bool:
|
||||||
|
"""Check if function is safe to be called in the event loop.
|
||||||
|
|
||||||
|
This version of is_callback will also check if the target is a partial
|
||||||
|
and walk the chain of partials to find the original function.
|
||||||
|
"""
|
||||||
|
check_target = target
|
||||||
|
while isinstance(check_target, functools.partial):
|
||||||
|
check_target = check_target.func
|
||||||
|
return is_callback(check_target)
|
||||||
|
|
||||||
|
|
||||||
class _Hass(threading.local):
|
class _Hass(threading.local):
|
||||||
"""Container which makes a HomeAssistant instance available to the event loop."""
|
"""Container which makes a HomeAssistant instance available to the event loop."""
|
||||||
|
|
||||||
@ -1141,9 +1153,9 @@ class EventBus:
|
|||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
if event_filter is not None and not is_callback(event_filter):
|
if event_filter is not None and not is_callback_check_partial(event_filter):
|
||||||
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
|
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
|
||||||
if run_immediately and not is_callback(listener):
|
if run_immediately and not is_callback_check_partial(listener):
|
||||||
raise HomeAssistantError(f"Event listener {listener} is not a callback")
|
raise HomeAssistantError(f"Event listener {listener} is not a callback")
|
||||||
return self._async_listen_filterable_job(
|
return self._async_listen_filterable_job(
|
||||||
event_type,
|
event_type,
|
||||||
|
@ -394,8 +394,8 @@ def _async_track_event(
|
|||||||
if listeners_key not in hass_data:
|
if listeners_key not in hass_data:
|
||||||
hass_data[listeners_key] = hass.bus.async_listen(
|
hass_data[listeners_key] = hass.bus.async_listen(
|
||||||
event_type,
|
event_type,
|
||||||
callback(ft.partial(dispatcher_callable, hass, callbacks)),
|
ft.partial(dispatcher_callable, hass, callbacks),
|
||||||
event_filter=callback(ft.partial(filter_callable, hass, callbacks)),
|
event_filter=ft.partial(filter_callable, hass, callbacks),
|
||||||
)
|
)
|
||||||
|
|
||||||
job = HassJob(action, f"track {event_type} event {keys}")
|
job = HassJob(action, f"track {event_type} event {keys}")
|
||||||
|
@ -2498,3 +2498,39 @@ async def test_get_release_channel(version: str, release_channel: str) -> None:
|
|||||||
"""Test if release channel detection works from Home Assistant version number."""
|
"""Test if release channel detection works from Home Assistant version number."""
|
||||||
with patch("homeassistant.core.__version__", f"{version}"):
|
with patch("homeassistant.core.__version__", f"{version}"):
|
||||||
assert get_release_channel() == release_channel
|
assert get_release_channel() == release_channel
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_callback_check_partial():
|
||||||
|
"""Test is_callback_check_partial matches HassJob."""
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def callback_func():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def not_callback_func():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert ha.is_callback(callback_func)
|
||||||
|
assert HassJob(callback_func).job_type == ha.HassJobType.Callback
|
||||||
|
assert ha.is_callback_check_partial(functools.partial(callback_func))
|
||||||
|
assert HassJob(functools.partial(callback_func)).job_type == ha.HassJobType.Callback
|
||||||
|
assert ha.is_callback_check_partial(
|
||||||
|
functools.partial(functools.partial(callback_func))
|
||||||
|
)
|
||||||
|
assert HassJob(functools.partial(functools.partial(callback_func))).job_type == (
|
||||||
|
ha.HassJobType.Callback
|
||||||
|
)
|
||||||
|
assert not ha.is_callback_check_partial(not_callback_func)
|
||||||
|
assert HassJob(not_callback_func).job_type == ha.HassJobType.Executor
|
||||||
|
assert not ha.is_callback_check_partial(functools.partial(not_callback_func))
|
||||||
|
assert HassJob(functools.partial(not_callback_func)).job_type == (
|
||||||
|
ha.HassJobType.Executor
|
||||||
|
)
|
||||||
|
|
||||||
|
# We check the inner function, not the outer one
|
||||||
|
assert not ha.is_callback_check_partial(
|
||||||
|
ha.callback(functools.partial(not_callback_func))
|
||||||
|
)
|
||||||
|
assert HassJob(ha.callback(functools.partial(not_callback_func))).job_type == (
|
||||||
|
ha.HassJobType.Executor
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user