From c63cab336c17ae8179026601c196603568d32be2 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 16 Sep 2024 07:50:43 -0500 Subject: [PATCH] Change wake word interception to a subscription (#125629) * Allow stopping intercepting wake words * Make wake word interception a subscription * Keep future * Add test for unsub --- .../assist_satellite/websocket_api.py | 19 ++- .../assist_satellite/test_websocket_api.py | 141 ++++++++++++++---- 2 files changed, 129 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py index 10687f4210e..8de10c8a9de 100644 --- a/homeassistant/components/assist_satellite/websocket_api.py +++ b/homeassistant/components/assist_satellite/websocket_api.py @@ -6,6 +6,7 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.entity_component import EntityComponent @@ -42,5 +43,19 @@ async def websocket_intercept_wake_word( ) return - wake_word_phrase = await satellite.async_intercept_wake_word() - connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase}) + async def intercept_wake_word() -> None: + """Push an intercepted wake word to websocket.""" + try: + wake_word_phrase = await satellite.async_intercept_wake_word() + connection.send_message( + websocket_api.event_message( + msg["id"], + {"wake_word_phrase": wake_word_phrase}, + ) + ) + except HomeAssistantError as err: + connection.send_error(msg["id"], "home_assistant_error", str(err)) + + task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word") + connection.subscriptions[msg["id"]] = task.cancel + connection.send_message(websocket_api.result_message(msg["id"])) diff --git a/tests/components/assist_satellite/test_websocket_api.py b/tests/components/assist_satellite/test_websocket_api.py index af49334e629..7895ea2555a 100644 --- a/tests/components/assist_satellite/test_websocket_api.py +++ b/tests/components/assist_satellite/test_websocket_api.py @@ -1,6 +1,9 @@ """Test WebSocket API.""" import asyncio +from unittest.mock import patch + +import pytest from homeassistant.components.assist_pipeline import PipelineStage from homeassistant.config_entries import ConfigEntry @@ -28,20 +31,23 @@ async def test_intercept_wake_word( "entity_id": ENTITY_ID, } ) - - for _ in range(3): - await asyncio.sleep(0) + msg = await ws_client.receive_json() + assert msg["success"] + assert msg["result"] is None + subscription_id = msg["id"] await entity.async_accept_pipeline_from_satellite( - object(), + object(), # type: ignore[arg-type] start_stage=PipelineStage.STT, wake_word_phrase="ok, nabu", ) - response = await ws_client.receive_json() + async with asyncio.timeout(1): + msg = await ws_client.receive_json() - assert response["success"] - assert response["result"] == {"wake_word_phrase": "ok, nabu"} + assert msg["id"] == subscription_id + assert msg["type"] == "event" + assert msg["event"] == {"wake_word_phrase": "ok, nabu"} async def test_intercept_wake_word_requires_on_device_wake_word( @@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word( } ) - for _ in range(3): - await asyncio.sleep(0) + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert msg["success"] + assert msg["result"] is None await entity.async_accept_pipeline_from_satellite( - object(), + object(), # type: ignore[arg-type] # Emulate wake word processing in Home Assistant start_stage=PipelineStage.WAKE_WORD, ) - response = await ws_client.receive_json() - assert not response["success"] - assert response["error"] == { + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"] == { "code": "home_assistant_error", "message": "Only on-device wake words currently supported", } @@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase( } ) - for _ in range(3): - await asyncio.sleep(0) + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert msg["success"] + assert msg["result"] is None await entity.async_accept_pipeline_from_satellite( - object(), + object(), # type: ignore[arg-type] start_stage=PipelineStage.STT, # We are not passing wake word phrase ) - response = await ws_client.receive_json() - assert not response["success"] - assert response["error"] == { + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"] == { "code": "home_assistant_error", "message": "No wake word phrase provided", } @@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin( "entity_id": ENTITY_ID, } ) - response = await ws_client.receive_json() - assert not response["success"] - assert response["error"] == { + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"] == { "code": "unauthorized", "message": "Unauthorized", } @@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite( "entity_id": "assist_satellite.invalid", } ) - response = await ws_client.receive_json() + async with asyncio.timeout(1): + msg = await ws_client.receive_json() - assert not response["success"] - assert response["error"] == { + assert not msg["success"] + assert msg["error"] == { "code": "not_found", "message": "Entity not found", } @@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice( entity: MockAssistSatellite, hass_ws_client: WebSocketGenerator, ) -> None: - """Test intercepting a wake word requires admin access.""" + """Test intercepting a wake word twice cancels the previous request.""" ws_client = await hass_ws_client(hass) await ws_client.send_json_auto_id( @@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice( } ) + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert msg["success"] + assert msg["result"] is None + + task = hass.async_create_task(ws_client.receive_json()) + await ws_client.send_json_auto_id( { "type": "assist_satellite/intercept_wake_word", "entity_id": ENTITY_ID, } ) - response = await ws_client.receive_json() - assert not response["success"] - assert response["error"] == { + # Should get an error from previous subscription + async with asyncio.timeout(1): + msg = await task + + assert not msg["success"] + assert msg["error"] == { "code": "home_assistant_error", "message": "Wake word interception already in progress", } + + # Response to second subscription + async with asyncio.timeout(1): + msg = await ws_client.receive_json() + + assert msg["success"] + assert msg["result"] is None + + +async def test_intercept_wake_word_unsubscribe( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test that closing the websocket connection stops interception.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + + # Wait for interception to start + for _ in range(3): + await asyncio.sleep(0) + + async def receive_json(): + with pytest.raises(TypeError): + # Raises TypeError when connection is closed + await ws_client.receive_json() + + task = hass.async_create_task(receive_json()) + + # Close connection + await ws_client.close() + await task + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + ) as mock_pipeline_from_audio_stream, + ): + # Start a pipeline with a wake word + await entity.async_accept_pipeline_from_satellite( + object(), + wake_word_phrase="ok, nabu", # type: ignore[arg-type] + ) + + # Wake word should not be intercepted + mock_pipeline_from_audio_stream.assert_called_once()