diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index f4d060ed7b8..21311e150ad 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -61,6 +61,7 @@ _LOGGER = logging.getLogger(__name__) STORAGE_KEY = f"{DOMAIN}.pipelines" STORAGE_VERSION = 1 +STORAGE_VERSION_MINOR = 2 ENGINE_LANGUAGE_PAIRS = ( ("stt_engine", "stt_language"), @@ -86,6 +87,8 @@ PIPELINE_FIELDS = { vol.Required("tts_engine"): vol.Any(str, None), vol.Required("tts_language"): vol.Any(str, None), vol.Required("tts_voice"): vol.Any(str, None), + vol.Required("wake_word_entity"): vol.Any(str, None), + vol.Required("wake_word_id"): vol.Any(str, None), } STORED_PIPELINE_RUNS = 10 @@ -111,6 +114,8 @@ async def _async_resolve_default_pipeline_settings( tts_engine = None tts_language = None tts_voice = None + wake_word_entity = None + wake_word_id = None # Find a matching language supported by the Home Assistant conversation agent conversation_languages = language_util.matches( @@ -188,6 +193,8 @@ async def _async_resolve_default_pipeline_settings( "tts_engine": tts_engine_id, "tts_language": tts_language, "tts_voice": tts_voice, + "wake_word_entity": wake_word_entity, + "wake_word_id": wake_word_id, } @@ -295,6 +302,8 @@ class Pipeline: tts_engine: str | None tts_language: str | None tts_voice: str | None + wake_word_entity: str | None + wake_word_id: str | None id: str = field(default_factory=ulid_util.ulid) @@ -316,6 +325,8 @@ class Pipeline: tts_engine=data["tts_engine"], tts_language=data["tts_language"], tts_voice=data["tts_voice"], + wake_word_entity=data["wake_word_entity"], + wake_word_id=data["wake_word_id"], ) def to_json(self) -> dict[str, Any]: @@ -331,6 +342,8 @@ class Pipeline: "tts_engine": self.tts_engine, "tts_language": self.tts_language, "tts_voice": self.tts_voice, + "wake_word_entity": self.wake_word_entity, + "wake_word_id": self.wake_word_id, } @@ -1382,11 +1395,35 @@ class PipelineRunDebug: ) +class PipelineStore(Store[SerializedPipelineStorageCollection]): + """Store entity registry data.""" + + async def _async_migrate_func( + self, + old_major_version: int, + old_minor_version: int, + old_data: SerializedPipelineStorageCollection, + ) -> SerializedPipelineStorageCollection: + """Migrate to the new version.""" + if old_major_version == 1 and old_minor_version < 2: + # Version 1.2 adds wake word configuration + for pipeline in old_data["items"]: + # Populate keys which were introduced before version 1.2 + pipeline.setdefault("wake_word_entity", None) + pipeline.setdefault("wake_word_id", None) + + if old_major_version > 1: + raise NotImplementedError + return old_data + + @singleton(DOMAIN) async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData: """Set up the pipeline storage collection.""" pipeline_store = PipelineStorageCollection( - Store(hass, STORAGE_VERSION, STORAGE_KEY) + PipelineStore( + hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR + ) ) await pipeline_store.async_load() PipelineStorageCollectionWebsocket( diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 8687e2ad40c..1a7362aab80 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -103,6 +103,8 @@ async def test_pipeline_from_audio_stream_legacy( "tts_engine": "test", "tts_language": "en-US", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, } ) msg = await client.receive_json() @@ -163,6 +165,8 @@ async def test_pipeline_from_audio_stream_entity( "tts_engine": "test", "tts_language": "en-US", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, } ) msg = await client.receive_json() @@ -223,6 +227,8 @@ async def test_pipeline_from_audio_stream_no_stt( "tts_engine": "test", "tts_language": "en-AU", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, } ) msg = await client.receive_json() diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 32468e3af91..5a84f4c2716 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -8,15 +8,16 @@ from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, + STORAGE_VERSION_MINOR, Pipeline, PipelineData, PipelineStorageCollection, + PipelineStore, async_create_default_pipeline, async_get_pipeline, async_get_pipelines, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers.storage import Store from homeassistant.setup import async_setup_component from . import MANY_LANGUAGES @@ -45,6 +46,8 @@ async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", }, { "conversation_engine": "conversation_engine_2", @@ -56,6 +59,8 @@ async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", + "wake_word_entity": "wakeword_entity_2", + "wake_word_id": "wakeword_id_2", }, { "conversation_engine": "conversation_engine_3", @@ -67,6 +72,8 @@ async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: "tts_engine": None, "tts_language": None, "tts_voice": None, + "wake_word_entity": "wakeword_entity_3", + "wake_word_id": "wakeword_id_3", }, ] pipeline_ids = [] @@ -81,7 +88,11 @@ async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: await store1.async_delete_item(pipeline_ids[1]) assert len(store1.data) == 3 - store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY)) + store2 = PipelineStorageCollection( + PipelineStore( + hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR + ) + ) await flush_store(store1.store) await store2.async_load() @@ -96,6 +107,71 @@ async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines on start.""" + hass_storage[STORAGE_KEY] = { + "version": STORAGE_VERSION, + "minor_version": STORAGE_VERSION_MINOR, + "key": "assist_pipeline.pipelines", + "data": { + "items": [ + { + "conversation_engine": "conversation_engine_1", + "conversation_language": "language_1", + "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "language": "language_1", + "name": "name_1", + "stt_engine": "stt_engine_1", + "stt_language": "language_1", + "tts_engine": "tts_engine_1", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", + }, + { + "conversation_engine": "conversation_engine_2", + "conversation_language": "language_2", + "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", + "language": "language_2", + "name": "name_2", + "stt_engine": "stt_engine_2", + "stt_language": "language_2", + "tts_engine": "tts_engine_2", + "tts_language": "language_2", + "tts_voice": "The Voice", + "wake_word_entity": "wakeword_entity_2", + "wake_word_id": "wakeword_id_2", + }, + { + "conversation_engine": "conversation_engine_3", + "conversation_language": "language_3", + "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", + "language": "language_3", + "name": "name_3", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "wake_word_entity": "wakeword_entity_3", + "wake_word_id": "wakeword_id_3", + }, + ], + "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + }, + } + + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 3 + assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" + + +async def test_migrate_pipeline_store( + hass: HomeAssistant, hass_storage: dict[str, Any] +) -> None: + """Test loading stored pipelines from an older version.""" hass_storage[STORAGE_KEY] = { "version": 1, "minor_version": 1, @@ -173,6 +249,8 @@ async def test_create_default_pipeline( tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", + wake_word_entity=None, + wake_word_id=None, ) @@ -213,6 +291,8 @@ async def test_get_pipelines(hass: HomeAssistant) -> None: tts_engine=None, tts_language=None, tts_voice=None, + wake_word_entity=None, + wake_word_id=None, ) ] @@ -258,6 +338,8 @@ async def test_default_pipeline_no_stt_tts( tts_engine=None, tts_language=None, tts_voice=None, + wake_word_entity=None, + wake_word_id=None, ) @@ -318,6 +400,8 @@ async def test_default_pipeline( tts_engine="test", tts_language=tts_language, tts_voice=None, + wake_word_entity=None, + wake_word_id=None, ) @@ -347,6 +431,8 @@ async def test_default_pipeline_unsupported_stt_language( tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", + wake_word_entity=None, + wake_word_id=None, ) @@ -376,6 +462,8 @@ async def test_default_pipeline_unsupported_tts_language( tts_engine=None, tts_language=None, tts_voice=None, + wake_word_entity=None, + wake_word_id=None, ) @@ -424,4 +512,6 @@ async def test_default_pipeline_cloud( tts_engine="cloud", tts_language="en-US", tts_voice="james_earl_jones", + wake_word_entity=None, + wake_word_id=None, ) diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index 1419eb58750..090c1034e4e 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -70,6 +70,8 @@ async def pipeline_1( "tts_voice": None, "stt_engine": None, "stt_language": None, + "wake_word_entity": None, + "wake_word_id": None, } ) @@ -90,6 +92,8 @@ async def pipeline_2( "tts_voice": None, "stt_engine": None, "stt_language": None, + "wake_word_entity": None, + "wake_word_id": None, } ) diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index a7ba9063b3f..a3ca7b62eb4 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -936,6 +936,8 @@ async def test_add_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -951,6 +953,8 @@ async def test_add_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } assert len(pipeline_store.data) == 2 @@ -966,6 +970,8 @@ async def test_add_pipeline( tts_engine="test_tts_engine", tts_language="test_language", tts_voice="Arnold Schwarzenegger", + wake_word_entity="wakeword_entity_1", + wake_word_id="wakeword_id_1", ) await client.send_json_auto_id( @@ -1000,6 +1006,8 @@ async def test_add_pipeline_missing_language( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -1018,6 +1026,8 @@ async def test_add_pipeline_missing_language( "tts_engine": "test_tts_engine", "tts_language": None, "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -1045,6 +1055,8 @@ async def test_delete_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -1063,6 +1075,8 @@ async def test_delete_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_2", + "wake_word_id": "wakeword_id_2", } ) msg = await client.receive_json() @@ -1143,6 +1157,8 @@ async def test_get_pipeline( "tts_engine": "test", "tts_language": "en-US", "tts_voice": "james_earl_jones", + "wake_word_entity": None, + "wake_word_id": None, } await client.send_json_auto_id( @@ -1170,6 +1186,8 @@ async def test_get_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -1196,6 +1214,8 @@ async def test_get_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } @@ -1221,6 +1241,8 @@ async def test_list_pipelines( "tts_engine": "test", "tts_language": "en-US", "tts_voice": "james_earl_jones", + "wake_word_entity": None, + "wake_word_id": None, } ], "preferred_pipeline": ANY, @@ -1248,6 +1270,8 @@ async def test_update_pipeline( "tts_engine": "new_tts_engine", "tts_language": "new_tts_language", "tts_voice": "new_tts_voice", + "wake_word_entity": "new_wakeword_entity", + "wake_word_id": "new_wakeword_id", } ) msg = await client.receive_json() @@ -1269,6 +1293,8 @@ async def test_update_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json() @@ -1289,6 +1315,8 @@ async def test_update_pipeline( "tts_engine": "new_tts_engine", "tts_language": "new_tts_language", "tts_voice": "new_tts_voice", + "wake_word_entity": "new_wakeword_entity", + "wake_word_id": "new_wakeword_id", } ) msg = await client.receive_json() @@ -1304,6 +1332,8 @@ async def test_update_pipeline( "tts_engine": "new_tts_engine", "tts_language": "new_tts_language", "tts_voice": "new_tts_voice", + "wake_word_entity": "new_wakeword_entity", + "wake_word_id": "new_wakeword_id", } assert len(pipeline_store.data) == 2 @@ -1319,6 +1349,8 @@ async def test_update_pipeline( tts_engine="new_tts_engine", tts_language="new_tts_language", tts_voice="new_tts_voice", + wake_word_entity="new_wakeword_entity", + wake_word_id="new_wakeword_id", ) await client.send_json_auto_id( @@ -1334,6 +1366,8 @@ async def test_update_pipeline( "tts_engine": None, "tts_language": None, "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, } ) msg = await client.receive_json() @@ -1349,6 +1383,8 @@ async def test_update_pipeline( "tts_engine": None, "tts_language": None, "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, } pipeline = pipeline_store.data[pipeline_id] @@ -1363,6 +1399,8 @@ async def test_update_pipeline( tts_engine=None, tts_language=None, tts_voice=None, + wake_word_entity=None, + wake_word_id=None, ) @@ -1386,6 +1424,8 @@ async def test_set_preferred_pipeline( "tts_engine": "test_tts_engine", "tts_language": "test_language", "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": "wakeword_entity_1", + "wake_word_id": "wakeword_id_1", } ) msg = await client.receive_json()