diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index af368b909b7..49be0aac705 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -237,7 +237,9 @@ class Recorder(threading.Thread): def async_initialize(self) -> None: """Initialize the recorder.""" self._event_listener = self.hass.bus.async_listen( - MATCH_ALL, self.event_listener, event_filter=self._async_event_filter + MATCH_ALL, + self.event_listener, + run_immediately=True, ) self._queue_watcher = async_track_time_interval( self.hass, self._async_check_queue, timedelta(minutes=10) @@ -916,7 +918,8 @@ class Recorder(threading.Thread): @callback def event_listener(self, event: Event) -> None: """Listen for new events and put them in the process queue.""" - self.queue_task(EventTask(event)) + if self._async_event_filter(event): + self.queue_task(EventTask(event)) def block_till_done(self) -> None: """Block till all events processed. diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index 794dae77153..9c074588a17 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -56,7 +56,7 @@ class AuthPhase: self, logger: WebSocketAdapter, hass: HomeAssistant, - send_message: Callable[[str | dict[str, Any]], None], + send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], cancel_ws: CALLBACK_TYPE, request: Request, ) -> None: diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 414913be002..edec628fd2c 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -105,17 +105,21 @@ def handle_subscribe_events( ): return - connection.send_message(messages.cached_event_message(msg["id"], event)) + connection.send_message( + lambda: messages.cached_event_message(msg["id"], event) + ) else: @callback def forward_events(event: Event) -> None: """Forward events to websocket.""" - connection.send_message(messages.cached_event_message(msg["id"], event)) + connection.send_message( + lambda: messages.cached_event_message(msg["id"], event) + ) connection.subscriptions[msg["id"]] = hass.bus.async_listen( - event_type, forward_events + event_type, forward_events, run_immediately=True ) connection.send_result(msg["id"]) @@ -286,14 +290,16 @@ def handle_subscribe_entities( if entity_ids and event.data["entity_id"] not in entity_ids: return - connection.send_message(messages.cached_state_diff_message(msg["id"], event)) + connection.send_message( + lambda: messages.cached_state_diff_message(msg["id"], event) + ) # We must never await between sending the states and listening for # state changed events or we will introduce a race condition # where some states are missed states = _async_get_allowed_states(hass, connection) connection.subscriptions[msg["id"]] = hass.bus.async_listen( - "state_changed", forward_entity_changes + EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True ) connection.send_result(msg["id"]) data: dict[str, dict[str, dict]] = { diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 9b53f358b85..cdcee408070 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -30,7 +30,7 @@ class ActiveConnection: self, logger: WebSocketAdapter, hass: HomeAssistant, - send_message: Callable[[str | dict[str, Any]], None], + send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], user: User, refresh_token: RefreshToken, ) -> None: diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 44b2aa8579c..a913b81c384 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -72,9 +72,13 @@ class WebSocketHandler: # Exceptions if Socket disconnected or cancelled by connection handler with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): while not self.wsock.closed: - if (message := await self._to_write.get()) is None: + if (process := await self._to_write.get()) is None: break + if not isinstance(process, str): + message: str = process() + else: + message = process self._logger.debug("Sending %s", message) await self.wsock.send_str(message) @@ -84,14 +88,14 @@ class WebSocketHandler: self._peak_checker_unsub = None @callback - def _send_message(self, message: str | dict[str, Any]) -> None: + def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None: """Send a message to the client. Closes connection if the client is not reading the messages. Async friendly. """ - if not isinstance(message, str): + if isinstance(message, dict): message = message_to_json(message) try: diff --git a/homeassistant/core.py b/homeassistant/core.py index b181e4c4106..916dd0c6f72 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -778,6 +778,7 @@ class _FilterableJob(NamedTuple): job: HassJob[None | Awaitable[None]] event_filter: Callable[[Event], bool] | None + run_immediately: bool class EventBus: @@ -845,7 +846,7 @@ class EventBus: if not listeners: return - for job, event_filter in listeners: + for job, event_filter, run_immediately in listeners: if event_filter is not None: try: if not event_filter(event): @@ -853,7 +854,13 @@ class EventBus: except Exception: # pylint: disable=broad-except _LOGGER.exception("Error in event filter") continue - self._hass.async_add_hass_job(job, event) + if run_immediately: + try: + job.target(event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error running job: %s", job) + else: + self._hass.async_add_hass_job(job, event) def listen( self, @@ -881,6 +888,7 @@ class EventBus: event_type: str, listener: Callable[[Event], None | Awaitable[None]], event_filter: Callable[[Event], bool] | None = None, + run_immediately: bool = False, ) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. @@ -891,12 +899,18 @@ class EventBus: @callback that returns a boolean value, determines if the listener callable should run. + If run_immediately is passed, the callback will be run + right away instead of using call_soon. Only use this if + the callback results in scheduling another task. + This method must be run in the event loop. """ if event_filter is not None and not is_callback(event_filter): raise HomeAssistantError(f"Event filter {event_filter} is not a callback") + if run_immediately and not is_callback(listener): + raise HomeAssistantError(f"Event listener {listener} is not a callback") return self._async_listen_filterable_job( - event_type, _FilterableJob(HassJob(listener), event_filter) + event_type, _FilterableJob(HassJob(listener), event_filter, run_immediately) ) @callback @@ -966,7 +980,7 @@ class EventBus: _onetime_listener, listener, ("__name__", "__qualname__", "__module__"), [] ) - filterable_job = _FilterableJob(HassJob(_onetime_listener), None) + filterable_job = _FilterableJob(HassJob(_onetime_listener), None, False) return self._async_listen_filterable_job(event_type, filterable_job) diff --git a/tests/test_core.py b/tests/test_core.py index 6885063a79a..c870605fc01 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -442,6 +442,24 @@ async def test_eventbus_filtered_listener(hass): unsub() +async def test_eventbus_run_immediately(hass): + """Test we can call events immediately.""" + calls = [] + + @ha.callback + def listener(event): + """Mock listener.""" + calls.append(event) + + unsub = hass.bus.async_listen("test", listener, run_immediately=True) + + hass.bus.async_fire("test", {"event": True}) + # No async_block_till_done here + assert len(calls) == 1 + + unsub() + + async def test_eventbus_unsubscribe_listener(hass): """Test unsubscribe listener from returned function.""" calls = []