From b3b83b7bb26a6c2de315cf63971f0accd0ea1aef Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 6 Apr 2023 18:55:16 +0200 Subject: [PATCH] Add a pipeline store to voice_assistant (#90844) * Add a pipeline store to voice_assistant * Improve error handling * Improve test coverage * Improve test coverage * Use StorageCollectionWebsocket * Correct rebase --- .../components/voice_assistant/__init__.py | 5 +- .../components/voice_assistant/const.py | 1 - .../components/voice_assistant/pipeline.py | 110 +++++++-- .../voice_assistant/websocket_api.py | 2 +- tests/components/voice_assistant/conftest.py | 2 +- tests/components/voice_assistant/test_init.py | 2 +- .../voice_assistant/test_pipeline.py | 104 +++++++++ .../voice_assistant/test_websocket.py | 217 +++++++++++++++++- 8 files changed, 420 insertions(+), 23 deletions(-) create mode 100644 tests/components/voice_assistant/test_pipeline.py diff --git a/homeassistant/components/voice_assistant/__init__.py b/homeassistant/components/voice_assistant/__init__.py index 8a2c04d8301..4edeb1e6bcd 100644 --- a/homeassistant/components/voice_assistant/__init__.py +++ b/homeassistant/components/voice_assistant/__init__.py @@ -17,6 +17,7 @@ from .pipeline import ( PipelineRun, PipelineStage, async_get_pipeline, + async_setup_pipeline_store, ) from .websocket_api import async_register_websocket_api @@ -31,7 +32,7 @@ __all__ = ( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Voice Assistant integration.""" - hass.data[DOMAIN] = {} + await async_setup_pipeline_store(hass) async_register_websocket_api(hass) return True @@ -61,7 +62,7 @@ async def async_pipeline_from_audio_stream( if context is None: context = Context() - pipeline = async_get_pipeline( + pipeline = await async_get_pipeline( hass, pipeline_id=pipeline_id, language=language, diff --git a/homeassistant/components/voice_assistant/const.py b/homeassistant/components/voice_assistant/const.py index 86572fb459f..f3006c98169 100644 --- a/homeassistant/components/voice_assistant/const.py +++ b/homeassistant/components/voice_assistant/const.py @@ -1,3 +1,2 @@ """Constants for the Voice Assistant integration.""" DOMAIN = "voice_assistant" -DEFAULT_PIPELINE = "default" diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py index 7c909c32819..26cc2d2d27e 100644 --- a/homeassistant/components/voice_assistant/pipeline.py +++ b/homeassistant/components/voice_assistant/pipeline.py @@ -7,13 +7,20 @@ from dataclasses import asdict, dataclass, field import logging from typing import Any +import voluptuous as vol + from homeassistant.backports.enum import StrEnum from homeassistant.components import conversation, media_source, stt, tts from homeassistant.components.tts.media_source import ( generate_media_source_id as tts_generate_media_source_id, ) from homeassistant.core import Context, HomeAssistant, callback -from homeassistant.util.dt import utcnow +from homeassistant.helpers.collection import ( + StorageCollection, + StorageCollectionWebsocket, +) +from homeassistant.helpers.storage import Store +from homeassistant.util import dt as dt_util, ulid as ulid_util from .const import DOMAIN from .error import ( @@ -25,23 +32,39 @@ from .error import ( _LOGGER = logging.getLogger(__name__) +STORAGE_KEY = f"{DOMAIN}.pipelines" +STORAGE_VERSION = 1 -@callback -def async_get_pipeline( +STORAGE_FIELDS = { + vol.Required("conversation_engine"): str, + vol.Required("language"): str, + vol.Required("name"): str, + vol.Required("stt_engine"): str, + vol.Required("tts_engine"): str, +} + +SAVE_DELAY = 10 + + +async def async_get_pipeline( hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None ) -> Pipeline | None: """Get a pipeline by id or create one for a language.""" + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + if pipeline_id is not None: - return hass.data[DOMAIN].get(pipeline_id) + return pipeline_store.data.get(pipeline_id) # Construct a pipeline for the required/configured language language = language or hass.config.language - return Pipeline( - name=language, - language=language, - stt_engine=None, # first engine - conversation_engine=None, # first agent - tts_engine=None, # first engine + return await pipeline_store.async_create_item( + { + "name": language, + "language": language, + "stt_engine": None, # first engine + "conversation_engine": None, # first agent + "tts_engine": None, # first engine + } ) @@ -65,7 +88,7 @@ class PipelineEvent: type: PipelineEventType data: dict[str, Any] | None = None - timestamp: str = field(default_factory=lambda: utcnow().isoformat()) + timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat()) def as_dict(self) -> dict[str, Any]: """Return a dict representation of the event.""" @@ -79,16 +102,29 @@ class PipelineEvent: PipelineEventCallback = Callable[[PipelineEvent], None] -@dataclass +@dataclass(frozen=True) class Pipeline: """A voice assistant pipeline.""" - name: str - language: str | None - stt_engine: str | None conversation_engine: str | None + language: str | None + name: str + stt_engine: str | None tts_engine: str | None + id: str = field(default_factory=ulid_util.ulid) + + def to_json(self) -> dict[str, Any]: + """Return a JSON serializable representation for storage.""" + return { + "conversation_engine": self.conversation_engine, + "id": self.id, + "language": self.language, + "name": self.name, + "stt_engine": self.stt_engine, + "tts_engine": self.tts_engine, + } + class PipelineStage(StrEnum): """Stages of a pipeline.""" @@ -478,3 +514,47 @@ class PipelineInput: if prepare_tasks: await asyncio.gather(*prepare_tasks) + + +class PipelineStorageCollection(StorageCollection[Pipeline]): + """Pipeline storage collection.""" + + CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) + + async def _process_create_data(self, data: dict) -> dict: + """Validate the config is valid.""" + # We don't need to validate, the WS API has already validated + return data + + @callback + def _get_suggested_id(self, info: dict) -> str: + """Suggest an ID based on the config.""" + return ulid_util.ulid() + + async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline: + """Return a new updated item.""" + return Pipeline(id=item.id, **update_data) + + def _create_item(self, item_id: str, data: dict) -> Pipeline: + """Create an item from validated config.""" + return Pipeline(id=item_id, **data) + + def _deserialize_item(self, data: dict) -> Pipeline: + """Create an item from its serialized representation.""" + return Pipeline(**data) + + def _serialize_item(self, item_id: str, item: Pipeline) -> dict: + """Return the serialized representation of an item.""" + return item.to_json() + + +async def async_setup_pipeline_store(hass): + """Set up the pipeline storage collection.""" + pipeline_store = PipelineStorageCollection( + Store(hass, STORAGE_VERSION, STORAGE_KEY) + ) + await pipeline_store.async_load() + StorageCollectionWebsocket( + pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS + ).async_setup(hass) + hass.data[DOMAIN] = pipeline_store diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py index 0df13fc19ea..42c22bfbed5 100644 --- a/homeassistant/components/voice_assistant/websocket_api.py +++ b/homeassistant/components/voice_assistant/websocket_api.py @@ -61,7 +61,7 @@ async def websocket_run( language = "en-US" pipeline_id = msg.get("pipeline") - pipeline = async_get_pipeline( + pipeline = await async_get_pipeline( hass, pipeline_id=pipeline_id, language=language, diff --git a/tests/components/voice_assistant/conftest.py b/tests/components/voice_assistant/conftest.py index 86da6334e09..b768c02ec44 100644 --- a/tests/components/voice_assistant/conftest.py +++ b/tests/components/voice_assistant/conftest.py @@ -117,7 +117,7 @@ async def mock_stt_provider(hass) -> MockSttProvider: return MockSttProvider(hass, _TRANSCRIPT) -@pytest.fixture(autouse=True) +@pytest.fixture async def init_components( hass: HomeAssistant, mock_stt_provider: MockSttProvider, diff --git a/tests/components/voice_assistant/test_init.py b/tests/components/voice_assistant/test_init.py index 1178f94c60c..c68aea9890c 100644 --- a/tests/components/voice_assistant/test_init.py +++ b/tests/components/voice_assistant/test_init.py @@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant async def test_pipeline_from_audio_stream( - hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion + hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion ) -> None: """Test creating a pipeline from an audio stream.""" diff --git a/tests/components/voice_assistant/test_pipeline.py b/tests/components/voice_assistant/test_pipeline.py new file mode 100644 index 00000000000..db1f3629483 --- /dev/null +++ b/tests/components/voice_assistant/test_pipeline.py @@ -0,0 +1,104 @@ +"""Websocket tests for Voice Assistant integration.""" +from typing import Any + +from homeassistant.components.voice_assistant.const import DOMAIN +from homeassistant.components.voice_assistant.pipeline import ( + STORAGE_KEY, + STORAGE_VERSION, + PipelineStorageCollection, +) +from homeassistant.core import HomeAssistant +from homeassistant.helpers.storage import Store +from homeassistant.setup import async_setup_component + +from tests.common import flush_store + + +async def test_load_datasets(hass: HomeAssistant, init_components) -> None: + """Make sure that we can load/save data correctly.""" + + pipelines = [ + { + "conversation_engine": "conversation_engine_1", + "language": "language_1", + "name": "name_1", + "stt_engine": "stt_engine_1", + "tts_engine": "tts_engine_1", + }, + { + "conversation_engine": "conversation_engine_2", + "language": "language_2", + "name": "name_2", + "stt_engine": "stt_engine_2", + "tts_engine": "tts_engine_2", + }, + { + "conversation_engine": "conversation_engine_3", + "language": "language_3", + "name": "name_3", + "stt_engine": "stt_engine_3", + "tts_engine": "tts_engine_3", + }, + ] + pipeline_ids = [] + + store1: PipelineStorageCollection = hass.data[DOMAIN] + for pipeline in pipelines: + pipeline_ids.append((await store1.async_create_item(pipeline)).id) + assert len(store1.data) == 3 + + await store1.async_delete_item(pipeline_ids[1]) + assert len(store1.data) == 2 + + store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY)) + await flush_store(store1.store) + await store2.async_load() + + assert len(store2.data) == 2 + + assert store1.data is not store2.data + assert store1.data == store2.data + + +async def test_loading_datasets_from_storage( + hass: HomeAssistant, hass_storage: dict[str, Any] +) -> None: + """Test loading stored datasets on start.""" + hass_storage[STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": "voice_assistant.pipelines", + "data": { + "items": [ + { + "conversation_engine": "conversation_engine_1", + "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "language": "language_1", + "name": "name_1", + "stt_engine": "stt_engine_1", + "tts_engine": "tts_engine_1", + }, + { + "conversation_engine": "conversation_engine_2", + "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", + "language": "language_2", + "name": "name_2", + "stt_engine": "stt_engine_2", + "tts_engine": "tts_engine_2", + }, + { + "conversation_engine": "conversation_engine_3", + "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", + "language": "language_3", + "name": "name_3", + "stt_engine": "stt_engine_3", + "tts_engine": "tts_engine_3", + }, + ] + }, + } + + assert await async_setup_component(hass, "voice_assistant", {}) + + store: PipelineStorageCollection = hass.data[DOMAIN] + assert len(store.data) == 3 diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py index 08dadcdd99d..938184b607a 100644 --- a/tests/components/voice_assistant/test_websocket.py +++ b/tests/components/voice_assistant/test_websocket.py @@ -1,9 +1,14 @@ """Websocket tests for Voice Assistant integration.""" import asyncio -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from syrupy.assertion import SnapshotAssertion +from homeassistant.components.voice_assistant.const import DOMAIN +from homeassistant.components.voice_assistant.pipeline import ( + Pipeline, + PipelineStorageCollection, +) from homeassistant.core import HomeAssistant from tests.typing import WebSocketGenerator @@ -12,6 +17,7 @@ from tests.typing import WebSocketGenerator async def test_text_only_pipeline( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + init_components, snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with text input (no STT/TTS).""" @@ -51,7 +57,10 @@ async def test_text_only_pipeline( async def test_audio_pipeline( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with audio input/output.""" client = await hass_ws_client(hass) @@ -271,6 +280,7 @@ async def test_audio_pipeline_timeout( async def test_stt_provider_missing( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + init_components, snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" @@ -297,6 +307,7 @@ async def test_stt_provider_missing( async def test_stt_stream_failed( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + init_components, snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" @@ -398,3 +409,205 @@ async def test_invalid_stage_order( # result msg = await client.receive_json() assert not msg["success"] + + +async def test_add_pipeline( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can add a pipeline.""" + client = await hass_ws_client(hass) + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "test_conversation_engine", + "id": ANY, + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + + assert len(pipeline_store.data) == 1 + pipeline = pipeline_store.data[msg["result"]["id"]] + assert pipeline == Pipeline( + conversation_engine="test_conversation_engine", + id=msg["result"]["id"], + language="test_language", + name="test_name", + stt_engine="test_stt_engine", + tts_engine="test_tts_engine", + ) + + +async def test_delete_pipeline( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can delete a pipeline.""" + client = await hass_ws_client(hass) + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert len(pipeline_store.data) == 1 + + pipeline_id = msg["result"]["id"] + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/delete", + "pipeline_id": pipeline_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert len(pipeline_store.data) == 0 + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/delete", + "pipeline_id": pipeline_id, + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": f"Unable to find pipeline_id {pipeline_id}", + } + + +async def test_list_pipelines( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can list pipelines.""" + client = await hass_ws_client(hass) + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + + await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == [] + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert len(pipeline_store.data) == 1 + + await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == [ + { + "conversation_engine": "test_conversation_engine", + "id": ANY, + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ] + + +async def test_update_pipeline( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can list pipelines.""" + client = await hass_ws_client(hass) + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/update", + "conversation_engine": "new_conversation_engine", + "language": "new_language", + "name": "new_name", + "pipeline_id": "no_such_pipeline", + "stt_engine": "new_stt_engine", + "tts_engine": "new_tts_engine", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "Unable to find pipeline_id no_such_pipeline", + } + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + assert len(pipeline_store.data) == 1 + + await client.send_json_auto_id( + { + "type": "voice_assistant/pipeline/update", + "conversation_engine": "new_conversation_engine", + "language": "new_language", + "name": "new_name", + "pipeline_id": pipeline_id, + "stt_engine": "new_stt_engine", + "tts_engine": "new_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "new_conversation_engine", + "language": "new_language", + "name": "new_name", + "id": pipeline_id, + "stt_engine": "new_stt_engine", + "tts_engine": "new_tts_engine", + } + + assert len(pipeline_store.data) == 1 + pipeline = pipeline_store.data[pipeline_id] + assert pipeline == Pipeline( + conversation_engine="new_conversation_engine", + id=pipeline_id, + language="new_language", + name="new_name", + stt_engine="new_stt_engine", + tts_engine="new_tts_engine", + )